mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Use the new MLX kernels to handle the BF16 matmul. (#2470)
This commit is contained in:
@@ -173,8 +173,8 @@ impl Device {
|
|||||||
|
|
||||||
pub fn supports_bf16(&self) -> bool {
|
pub fn supports_bf16(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
Self::Cuda(_) => true,
|
Self::Cuda(_) | Self::Metal(_) => true,
|
||||||
Self::Metal(_) | Self::Cpu => false,
|
Self::Cpu => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1398,6 +1398,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(acc)
|
Ok(acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn matmul(
|
fn matmul(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
@@ -1406,17 +1407,35 @@ impl BackendStorage for MetalStorage {
|
|||||||
rhs_l: &Layout,
|
rhs_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
||||||
|
let command_buffer = self.device.command_buffer()?;
|
||||||
|
command_buffer.set_label("matmul");
|
||||||
|
if self.dtype == DType::BF16 {
|
||||||
|
candle_metal_kernels::call_mlx_gemm(
|
||||||
|
&self.device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.device.kernels,
|
||||||
|
candle_metal_kernels::GemmDType::BF16,
|
||||||
|
(b, m, n, k),
|
||||||
|
lhs_l.stride(),
|
||||||
|
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||||
|
&self.buffer,
|
||||||
|
rhs_l.stride(),
|
||||||
|
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||||
|
&rhs.buffer,
|
||||||
|
&buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
} else {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "sgemm",
|
DType::F32 => "sgemm",
|
||||||
DType::F16 => "hgemm",
|
DType::F16 => "hgemm",
|
||||||
DType::BF16 => "bgemm",
|
|
||||||
dtype => {
|
dtype => {
|
||||||
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
|
return Err(
|
||||||
|
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let command_buffer = self.device.command_buffer()?;
|
|
||||||
command_buffer.set_label("matmul");
|
|
||||||
candle_metal_kernels::call_gemm(
|
candle_metal_kernels::call_gemm(
|
||||||
&self.device.device,
|
&self.device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
@@ -1432,6 +1451,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
Ok(Self::new(
|
Ok(Self::new(
|
||||||
buffer,
|
buffer,
|
||||||
self.device.clone(),
|
self.device.clone(),
|
||||||
|
Reference in New Issue
Block a user