Remove redundant mlx gemm dtype check (#2925)

This commit is contained in:
ivarflakstad
2025-04-27 06:14:57 +02:00
committed by GitHub
parent fbaf0b0e32
commit 6e0646c208

View File

@ -1655,50 +1655,32 @@ impl BackendStorage for MetalStorage {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul"); command_buffer.set_label("matmul");
if self.dtype == DType::BF16 { let dtype = match self.dtype {
candle_metal_kernels::call_mlx_gemm( DType::F32 => candle_metal_kernels::GemmDType::F32,
&self.device.device, DType::F16 => candle_metal_kernels::GemmDType::F16,
&command_buffer, DType::BF16 => candle_metal_kernels::GemmDType::BF16,
&self.device.kernels, dtype => {
candle_metal_kernels::GemmDType::BF16, return Err(
(b, m, n, k), MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
lhs_l.stride(), )
lhs_l.start_offset() * self.dtype.size_in_bytes(), }
&self.buffer, };
rhs_l.stride(), candle_metal_kernels::call_mlx_gemm(
rhs_l.start_offset() * rhs.dtype.size_in_bytes(), &self.device.device,
&rhs.buffer, &command_buffer,
&buffer, &self.device.kernels,
) dtype,
.map_err(MetalError::from)?; (b, m, n, k),
} else { lhs_l.stride(),
let dtype = match self.dtype { lhs_l.start_offset() * self.dtype.size_in_bytes(),
DType::F32 => candle_metal_kernels::GemmDType::F32, &self.buffer,
DType::F16 => candle_metal_kernels::GemmDType::F16, rhs_l.stride(),
DType::BF16 => candle_metal_kernels::GemmDType::BF16, rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
dtype => { &rhs.buffer,
return Err(MetalError::Message(format!( &buffer,
"mlx matmul doesn't support {dtype:?}" )
)) .map_err(MetalError::from)?;
.into())
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
Ok(Self::new( Ok(Self::new(
buffer, buffer,
self.device.clone(), self.device.clone(),