Use the new MLX kernels to handle the BF16 matmul. (#2470)

This commit is contained in:
Laurent Mazare
2024-09-11 16:34:05 +01:00
committed by GitHub
parent 5635650d38
commit afb6575835
2 changed files with 46 additions and 26 deletions

View File

@ -173,8 +173,8 @@ impl Device {
pub fn supports_bf16(&self) -> bool { pub fn supports_bf16(&self) -> bool {
match self { match self {
Self::Cuda(_) => true, Self::Cuda(_) | Self::Metal(_) => true,
Self::Metal(_) | Self::Cpu => false, Self::Cpu => false,
} }
} }

View File

@ -1398,6 +1398,7 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
Ok(acc) Ok(acc)
} }
fn matmul( fn matmul(
&self, &self,
rhs: &Self, rhs: &Self,
@ -1406,32 +1407,51 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout, rhs_l: &Layout,
) -> Result<Self> { ) -> Result<Self> {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
DType::BF16 => "bgemm",
dtype => {
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
}
};
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul"); command_buffer.set_label("matmul");
candle_metal_kernels::call_gemm( if self.dtype == DType::BF16 {
&self.device.device, candle_metal_kernels::call_mlx_gemm(
&command_buffer, &self.device.device,
&self.device.kernels, &command_buffer,
name, &self.device.kernels,
(b, m, n, k), candle_metal_kernels::GemmDType::BF16,
lhs_l.stride(), (b, m, n, k),
lhs_l.start_offset() * self.dtype.size_in_bytes(), lhs_l.stride(),
&self.buffer, lhs_l.start_offset() * self.dtype.size_in_bytes(),
rhs_l.stride(), &self.buffer,
rhs_l.start_offset() * rhs.dtype.size_in_bytes(), rhs_l.stride(),
&rhs.buffer, rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer, &rhs.buffer,
) &buffer,
.map_err(MetalError::from)?; )
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
dtype => {
return Err(
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
Ok(Self::new( Ok(Self::new(
buffer, buffer,
self.device.clone(), self.device.clone(),