Use metal encode_gemm

This commit is contained in:
Ivar Flakstad
2023-11-06 03:27:22 +01:00
parent e6d33a8efb
commit 6d4c8c0707

View File

@ -1,15 +1,15 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::bail;
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
use crate::error::Error;
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use core::mem;
use half::{bf16, f16};
use metal;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType};
use metal::{Buffer, MTLResourceOptions};
use metal::mps::matrix::encode_gemm;
use metal::mps::Float32;
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::sync::Arc;
/// Metal related errors
@ -288,7 +288,7 @@ impl MetalStorage {
let elem_count = b * m * n;
match (self.dtype, rhs.dtype) {
(DType::F32, DType::F32) => {
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
if b != 1 {
println!("TODO implement batched matmul for B={b}");
// bail!("Didn't implement strided matmul yet");
@ -310,54 +310,28 @@ impl MetalStorage {
dtype: self.dtype(),
});
}
let m: u64 = m.try_into().expect("usize should fit u64");
let n: u64 = n.try_into().expect("usize should fit u64");
let k: u64 = k.try_into().expect("usize should fit u64");
// Create descriptors
let left_descriptor =
MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
let right_descriptor =
MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
let result_descriptor =
MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
encode_gemm::<Float32, Float32, Float32>(
&self.device,
&self.device.command_buffer,
transpose_left,
transpose_right,
&self.buffer,
&rhs.buffer,
&mut out_buffer,
m as NSUInteger,
n as NSUInteger,
k as NSUInteger,
alpha,
beta,
)
.map_err(|e| Error::Metal(e))?;
println!("lhs {:?} {m} {k}", self.buffer.length());
println!("rhs {:?} {k} {n}", rhs.buffer.length());
println!("out {:?} {m} {n}", out_buffer.length());
// Create matrix objects
let left_matrix =
Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
.expect("Failed to create left matrix");
let right_matrix =
Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
.expect("Failed to create left matrix");
let result_matrix =
Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
.expect("Failed to create left matrix");
println!("lhs {:?}", lhs_l.shape());
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.expect("Failed to create matrix multiplication kernel");
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&self.device.command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
Ok(Self {
buffer: Arc::new(out_buffer),
device: self.device.clone(),