mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Use the new MLX kernels to handle the BF16 matmul. (#2470)
This commit is contained in:
@ -173,8 +173,8 @@ impl Device {
|
|||||||
|
|
||||||
pub fn supports_bf16(&self) -> bool {
|
pub fn supports_bf16(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
Self::Cuda(_) => true,
|
Self::Cuda(_) | Self::Metal(_) => true,
|
||||||
Self::Metal(_) | Self::Cpu => false,
|
Self::Cpu => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1398,6 +1398,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(acc)
|
Ok(acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn matmul(
|
fn matmul(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
@ -1406,32 +1407,51 @@ impl BackendStorage for MetalStorage {
|
|||||||
rhs_l: &Layout,
|
rhs_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
||||||
let name = match self.dtype {
|
|
||||||
DType::F32 => "sgemm",
|
|
||||||
DType::F16 => "hgemm",
|
|
||||||
DType::BF16 => "bgemm",
|
|
||||||
dtype => {
|
|
||||||
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
command_buffer.set_label("matmul");
|
command_buffer.set_label("matmul");
|
||||||
candle_metal_kernels::call_gemm(
|
if self.dtype == DType::BF16 {
|
||||||
&self.device.device,
|
candle_metal_kernels::call_mlx_gemm(
|
||||||
&command_buffer,
|
&self.device.device,
|
||||||
&self.device.kernels,
|
&command_buffer,
|
||||||
name,
|
&self.device.kernels,
|
||||||
(b, m, n, k),
|
candle_metal_kernels::GemmDType::BF16,
|
||||||
lhs_l.stride(),
|
(b, m, n, k),
|
||||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
lhs_l.stride(),
|
||||||
&self.buffer,
|
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||||
rhs_l.stride(),
|
&self.buffer,
|
||||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
rhs_l.stride(),
|
||||||
&rhs.buffer,
|
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||||
&buffer,
|
&rhs.buffer,
|
||||||
)
|
&buffer,
|
||||||
.map_err(MetalError::from)?;
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
} else {
|
||||||
|
let name = match self.dtype {
|
||||||
|
DType::F32 => "sgemm",
|
||||||
|
DType::F16 => "hgemm",
|
||||||
|
dtype => {
|
||||||
|
return Err(
|
||||||
|
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
candle_metal_kernels::call_gemm(
|
||||||
|
&self.device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.device.kernels,
|
||||||
|
name,
|
||||||
|
(b, m, n, k),
|
||||||
|
lhs_l.stride(),
|
||||||
|
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||||
|
&self.buffer,
|
||||||
|
rhs_l.stride(),
|
||||||
|
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||||
|
&rhs.buffer,
|
||||||
|
&buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
Ok(Self::new(
|
Ok(Self::new(
|
||||||
buffer,
|
buffer,
|
||||||
self.device.clone(),
|
self.device.clone(),
|
||||||
|
Reference in New Issue
Block a user