diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index cdbeb65d..2bb07ea4 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1655,50 +1655,32 @@ impl BackendStorage for MetalStorage { let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; let command_buffer = self.device.command_buffer()?; command_buffer.set_label("matmul"); - if self.dtype == DType::BF16 { - candle_metal_kernels::call_mlx_gemm( - &self.device.device, - &command_buffer, - &self.device.kernels, - candle_metal_kernels::GemmDType::BF16, - (b, m, n, k), - lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &self.buffer, - rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &rhs.buffer, - &buffer, - ) - .map_err(MetalError::from)?; - } else { - let dtype = match self.dtype { - DType::F32 => candle_metal_kernels::GemmDType::F32, - DType::F16 => candle_metal_kernels::GemmDType::F16, - DType::BF16 => candle_metal_kernels::GemmDType::BF16, - dtype => { - return Err(MetalError::Message(format!( - "mlx matmul doesn't support {dtype:?}" - )) - .into()) - } - }; - candle_metal_kernels::call_mlx_gemm( - &self.device.device, - &command_buffer, - &self.device.kernels, - dtype, - (b, m, n, k), - lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &self.buffer, - rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &rhs.buffer, - &buffer, - ) - .map_err(MetalError::from)?; - } + let dtype = match self.dtype { + DType::F32 => candle_metal_kernels::GemmDType::F32, + DType::F16 => candle_metal_kernels::GemmDType::F16, + DType::BF16 => candle_metal_kernels::GemmDType::BF16, + dtype => { + return Err( + MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(), + ) + } + }; + candle_metal_kernels::call_mlx_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + dtype, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new( buffer, self.device.clone(),