From cf27868b57f20cba869b4ca425b0b9c09c724822 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 01:44:22 +0100 Subject: [PATCH] More cleanup. --- candle-metal-kernels/src/lib.rs | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 0c383dec..60f9b8a6 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -173,6 +173,12 @@ pub enum MetalKernelError { FailedToCreateComputeFunction, #[error("Failed to create pipeline")] FailedToCreatePipeline(String), + #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")] + MatMulNonContiguous { + lhs_stride: Vec<usize>, + rhs_stride: Vec<usize>, + mnk: (usize, usize, usize), + }, } impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { @@ -1029,24 +1035,22 @@ pub fn call_gemm( } else if lhs_m1 == m && lhs_m2 == 1 { true } else { - todo!(); - // Err(MetalError::MatMulNonContiguous { - // lhs_stride: lhs_stride.to_vec(), - // rhs_stride: rhs_stride.to_vec(), - // mnk: (m, n, k), - // })? + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; }; let b_trans = if rhs_m1 == 1 && rhs_m2 == n { false } else if rhs_m1 == k && rhs_m2 == 1 { true } else { - todo!(); - // Err(MetalError::MatMulNonContiguous { - // lhs_stride: lhs_stride.to_vec(), - // rhs_stride: rhs_stride.to_vec(), - // mnk: (m, n, k), - // })? 
+ return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; }; let d_trans = false; let alpha = 1.0f32; @@ -1083,7 +1087,6 @@ pub fn call_gemm( (211, Value::U16(n_splits)), (50_001, Value::Bool(fused_bias)), ])); - // println!("Constants {constants:?}"); let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; let m_group = m_simd * m_splits; let n_group = n_simd * n_splits; @@ -1103,7 +1106,6 @@ pub fn call_gemm( block_elements = std::cmp::max(block_elements, n_group); } } - // TODO adapt for f16 let bytes = match name { "sgemm" => 4, "hgemm" => 2, @@ -1118,7 +1120,6 @@ pub fn call_gemm( let encoder = command_buffer.new_compute_command_encoder(); encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - // println!("Threadgroup {block_bytes}"); encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);