RmsNorm kernel for metal. (#1895)

* RmsNorm kernel for metal.

* Wrapper for the metal kernel.

* Get the ops to actually work.

* Fix, get the tests to pass.
Author: Laurent Mazare
Date: 2024-03-21 09:48:56 +01:00
Committed by: GitHub
Parent: 74b7f59261
Commit: 0fddec762e

3 changed files with 160 additions and 1 deletion

@@ -236,7 +236,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
             layout.start_offset() * storage.dtype().size_in_bytes(),
             &output,
         )
-        .unwrap();
+        .map_err(candle::Error::wrap)?;
         let newstorage =
             candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
         Ok((newstorage, layout.shape().clone()))
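
Note: the hunk above swaps a hard `.unwrap()` for error propagation, wrapping the Metal error into a `candle::Error` so a failed kernel dispatch surfaces through the `Result` instead of aborting. A minimal sketch of the same pattern, with a hypothetical `launch_kernel` standing in for the Metal call:

use candle::{Error, Result};

// Hypothetical stand-in for a fallible Metal dispatch; any error type
// implementing std::error::Error + Send + Sync can be wrapped.
fn launch_kernel() -> std::io::Result<()> {
    Err(std::io::Error::new(std::io::ErrorKind::Other, "command buffer error"))
}

// Wrap the foreign error into candle::Error and propagate with `?`
// instead of panicking in library code.
fn forward() -> Result<()> {
    launch_kernel().map_err(Error::wrap)?;
    Ok(())
}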
@@ -383,6 +383,51 @@ impl candle::CustomOp2 for RmsNorm {
         };
         Ok((dst, l1.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        s1: &candle::MetalStorage,
+        l1: &Layout,
+        s2: &candle::MetalStorage,
+        l2: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        let device = s1.device();
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        let name = match (s1.dtype(), s2.dtype()) {
+            (DType::F32, DType::F32) => "rmsnorm_f32",
+            (DType::F16, DType::F16) => "rmsnorm_f16",
+            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
+            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
+        };
+
+        if !(l1.is_contiguous() && l2.is_contiguous()) {
+            candle::bail!("Non contiguous rmsnorm is not implemented");
+        }
+
+        let last_dim = l1.dims()[l1.shape().rank() - 1];
+        let elem_count = l1.shape().elem_count();
+        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
+        candle_metal_kernels::call_rms_norm(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            elem_count,
+            last_dim,
+            self.eps,
+            s1.buffer(),
+            l1.start_offset() * s1.dtype().size_in_bytes(),
+            s2.buffer(),
+            l2.start_offset() * s2.dtype().size_in_bytes(),
+            &output,
+        )
+        .map_err(candle::Error::wrap)?;
+        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
+        Ok((newstorage, l1.shape().clone()))
+    }
 }

 pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
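
Note: the kernel normalizes over the last dimension, computing y = x * alpha / sqrt(mean(x^2) + eps). A quick way to exercise the new path is to compare the fast op against rms_norm_slow on a Metal device; a minimal sketch, assuming a metal-enabled build and that candle_nn::ops::rms_norm dispatches to the custom op above (the function name and tolerance are illustrative, not part of the commit):

use candle::{DType, Device, Result, Tensor};

// Consistency check: the Metal kernel and the pure tensor-op reference
// should agree up to floating-point noise.
fn check_rms_norm_metal() -> Result<()> {
    let device = Device::new_metal(0)?;
    let x = Tensor::randn(0f32, 1f32, (4, 512), &device)?;
    let alpha = Tensor::ones(512, DType::F32, &device)?;
    // Fast path: goes through the metal_fwd implementation above.
    let fast = candle_nn::ops::rms_norm(&x, &alpha, 1e-5)?;
    // Reference path built from plain tensor ops.
    let slow = candle_nn::ops::rms_norm_slow(&x, &alpha, 1e-5)?;
    // Mean absolute difference, reduced to a scalar on the host.
    let diff = (fast - slow)?.abs()?.mean_all()?.to_vec0::<f32>()?;
    assert!(diff < 1e-4, "metal rmsnorm mismatch: {diff}");
    Ok(())
}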