mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal. * Wrapper for the metal kernel. * Get the ops to actually work. * Fix, get the tests to pass.
This commit is contained in:
@ -236,7 +236,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let newstorage =
|
||||
candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
|
||||
Ok((newstorage, layout.shape().clone()))
|
||||
@ -383,6 +383,51 @@ impl candle::CustomOp2 for RmsNorm {
|
||||
};
|
||||
Ok((dst, l1.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
s1: &candle::MetalStorage,
|
||||
l1: &Layout,
|
||||
s2: &candle::MetalStorage,
|
||||
l2: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
let device = s1.device();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let kernels = device.kernels();
|
||||
let name = match (s1.dtype(), s2.dtype()) {
|
||||
(DType::F32, DType::F32) => "rmsnorm_f32",
|
||||
(DType::F16, DType::F16) => "rmsnorm_f16",
|
||||
(DType::BF16, DType::BF16) => "rmsnorm_bf16",
|
||||
(dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
|
||||
};
|
||||
|
||||
if !(l1.is_contiguous() && l2.is_contiguous()) {
|
||||
candle::bail!("Non contiguous rmsnorm is not implemented");
|
||||
}
|
||||
|
||||
let last_dim = l1.dims()[l1.shape().rank() - 1];
|
||||
let elem_count = l1.shape().elem_count();
|
||||
let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
|
||||
candle_metal_kernels::call_rms_norm(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
kernels,
|
||||
name,
|
||||
elem_count,
|
||||
last_dim,
|
||||
self.eps,
|
||||
s1.buffer(),
|
||||
l1.start_offset() * s1.dtype().size_in_bytes(),
|
||||
s2.buffer(),
|
||||
l2.start_offset() * s2.dtype().size_in_bytes(),
|
||||
&output,
|
||||
)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
|
||||
Ok((newstorage, l1.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
|
||||
|
Reference in New Issue
Block a user