Stash for debugging

This commit is contained in:
Ivar Flakstad
2023-12-10 13:11:53 +01:00
parent 35352e441a
commit ce0783d9ff
2 changed files with 285 additions and 80 deletions

View File

@ -795,14 +795,16 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
// Create descriptors
let (type_id, size) = match self.dtype {
let (type_id, size, name) = match self.dtype {
DType::F32 => (
metal::mps::MPS_FLOATBIT_ENCODING | 32,
core::mem::size_of::<f32>() as NSUInteger,
"sgemm",
),
DType::F16 => (
metal::mps::MPS_FLOATBIT_ENCODING | 16,
core::mem::size_of::<f16>() as NSUInteger,
"hgemm",
),
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
};
@ -836,60 +838,37 @@ impl BackendStorage for MetalStorage {
mnk: (m, n, k),
})?
};
let b = b as NSUInteger;
let m = m as NSUInteger;
let n = n as NSUInteger;
let k = k as NSUInteger;
let left_matrix = self.matrix(
(b, m, k),
transpose_left,
size,
lhs_l.start_offset() as NSUInteger * size,
type_id,
)?;
let right_matrix = rhs.matrix(
(b, k, n),
transpose_right,
size,
rhs_l.start_offset() as NSUInteger * size,
type_id,
)?;
let (result_matrix, out_buffer) =
self.device
.new_matrix((b, m, n), size, type_id, self.dtype)?;
let result_buffer = self.device.new_buffer(b * m * n, self.dtype);
let command_buffer = self.device.command_buffer();
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
command_buffer.set_label("mfa gemm");
candle_metal_kernels::call_mfa_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
&self.buffer,
lhs_l.shape().dims(),
&rhs.buffer,
rhs_l.shape().dims(),
&result_buffer,
(b, m, n, k),
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
.map_err(MetalError::from)?;
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
command_buffer.set_label("matmul");
drop(command_buffer);
self.device.commit();
Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
Ok(Self::new(
self.buffer.clone(),
self.device.clone(),
self.dtype(),
))
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {