mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Remove redundant mlx gemm dtype check (#2925)
This commit is contained in:
@ -1655,50 +1655,32 @@ impl BackendStorage for MetalStorage {
|
|||||||
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
command_buffer.set_label("matmul");
|
command_buffer.set_label("matmul");
|
||||||
if self.dtype == DType::BF16 {
|
let dtype = match self.dtype {
|
||||||
candle_metal_kernels::call_mlx_gemm(
|
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
||||||
&self.device.device,
|
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
||||||
&command_buffer,
|
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
|
||||||
&self.device.kernels,
|
dtype => {
|
||||||
candle_metal_kernels::GemmDType::BF16,
|
return Err(
|
||||||
(b, m, n, k),
|
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
|
||||||
lhs_l.stride(),
|
)
|
||||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
}
|
||||||
&self.buffer,
|
};
|
||||||
rhs_l.stride(),
|
candle_metal_kernels::call_mlx_gemm(
|
||||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
&self.device.device,
|
||||||
&rhs.buffer,
|
&command_buffer,
|
||||||
&buffer,
|
&self.device.kernels,
|
||||||
)
|
dtype,
|
||||||
.map_err(MetalError::from)?;
|
(b, m, n, k),
|
||||||
} else {
|
lhs_l.stride(),
|
||||||
let dtype = match self.dtype {
|
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||||
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
&self.buffer,
|
||||||
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
rhs_l.stride(),
|
||||||
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
|
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||||
dtype => {
|
&rhs.buffer,
|
||||||
return Err(MetalError::Message(format!(
|
&buffer,
|
||||||
"mlx matmul doesn't support {dtype:?}"
|
)
|
||||||
))
|
.map_err(MetalError::from)?;
|
||||||
.into())
|
|
||||||
}
|
|
||||||
};
|
|
||||||
candle_metal_kernels::call_mlx_gemm(
|
|
||||||
&self.device.device,
|
|
||||||
&command_buffer,
|
|
||||||
&self.device.kernels,
|
|
||||||
dtype,
|
|
||||||
(b, m, n, k),
|
|
||||||
lhs_l.stride(),
|
|
||||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
|
||||||
&self.buffer,
|
|
||||||
rhs_l.stride(),
|
|
||||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
|
||||||
&rhs.buffer,
|
|
||||||
&buffer,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
}
|
|
||||||
Ok(Self::new(
|
Ok(Self::new(
|
||||||
buffer,
|
buffer,
|
||||||
self.device.clone(),
|
self.device.clone(),
|
||||||
|
Reference in New Issue
Block a user