From 72d649058b23b9cb2e2140194c4040f1e56d938e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 12 Sep 2024 12:52:59 +0100 Subject: [PATCH] Hook the MLX matmul kernels in candle-core. (#2473) --- candle-core/src/metal_backend/device.rs | 6 +++++ candle-core/src/metal_backend/mod.rs | 32 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 07210c68..3deb465b 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -70,6 +70,8 @@ pub struct MetalDevice { pub(crate) buffers: AllocatedBuffers, /// Seed for random number generation. pub(crate) seed: Arc>, + /// Whether to use the MLX matmul kernels instead of the MFA ones. + pub(crate) use_mlx_mm: bool, } impl std::fmt::Debug for MetalDevice { @@ -87,6 +89,10 @@ impl std::ops::Deref for MetalDevice { } impl MetalDevice { + pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) { + self.use_mlx_mm = use_mlx_mm + } + pub fn id(&self) -> DeviceId { self.id } diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 19557cf2..9845a42f 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1425,6 +1425,33 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + } else if self.device.use_mlx_mm { + let dtype = match self.dtype { + DType::F32 => candle_metal_kernels::GemmDType::F32, + DType::F16 => candle_metal_kernels::GemmDType::F16, + DType::BF16 => candle_metal_kernels::GemmDType::BF16, + dtype => { + return Err(MetalError::Message(format!( + "mlx matmul doesn't support {dtype:?}" + )) + .into()) + } + }; + candle_metal_kernels::call_mlx_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + dtype, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; } else { let name = match self.dtype { DType::F32 => "sgemm", @@ -1818,6 +1845,10 @@ impl BackendDevice for MetalDevice { let command_buffer_index = Arc::new(RwLock::new(0)); let kernels = Arc::new(Kernels::new()); let buffers = Arc::new(RwLock::new(HashMap::new())); + let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() { + Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false, + Ok(_) => true, + }; let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { Ok(val) => val.parse()?, _ => 50, @@ -1837,6 +1868,7 @@ impl BackendDevice for MetalDevice { buffers, kernels, seed, + use_mlx_mm, }) }