Mirror of https://github.com/huggingface/candle.git
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
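For orientation, a minimal usage sketch of the new entry point added in the diff below, assuming it is exposed as `candle_nn::ops::layer_norm` alongside the surrounding functions such as `rms_norm` and `pixel_shuffle`; the shapes and epsilon here are illustrative, not taken from the commit:

use candle::{DType, Device, Result, Tensor};

fn layer_norm_example() -> Result<()> {
    let dev = Device::Cpu;
    // Illustrative shape: a (batch, seq, hidden) activation with hidden size 768.
    let xs = Tensor::randn(0f32, 1f32, (2, 5, 768), &dev)?;
    let alpha = Tensor::ones(768, DType::F32, &dev)?; // per-channel scale
    let beta = Tensor::zeros(768, DType::F32, &dev)?; // per-channel shift
    // The specialized op normalizes over the last dimension, then applies scale and shift.
    let ys = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
    assert_eq!(ys.dims(), xs.dims());
    Ok(())
}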
@@ -1,4 +1,4 @@
-use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D};
 use rayon::prelude::*;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -39,7 +39,7 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
 }
 
 pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
-    let xs = xs.chunk(2, candle::D::Minus1)?;
+    let xs = xs.chunk(2, D::Minus1)?;
     &xs[0].silu()? * &xs[1]
 }
 
@@ -620,15 +620,15 @@ pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
         DType::F16 | DType::BF16 => DType::F32,
         d => d,
     };
-    let hidden_size = x.dim(candle::D::Minus1)?;
+    let hidden_size = x.dim(D::Minus1)?;
     let x = x.to_dtype(internal_dtype)?;
-    let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
+    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
     let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
     x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
 }
 
 pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
-    let hidden_size_xs = xs.dim(candle::D::Minus1)?;
+    let hidden_size_xs = xs.dim(D::Minus1)?;
     let hidden_size_alpha = alpha.dims1()?;
     if hidden_size_xs != hidden_size_alpha {
         candle::bail!(
@@ -640,6 +640,254 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
     xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
 }
 
+#[derive(Debug, Clone)]
+struct LayerNorm {
+    eps: f32,
+}
+
+impl candle::CustomOp3 for LayerNorm {
+    fn name(&self) -> &'static str {
+        "layer-norm"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+        s3: &CpuStorage,
+        l3: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        use candle::backend::BackendStorage;
+
+        let eps = self.eps;
+        fn inner<
+            T: candle::WithDType
+                + num_traits::Float
+                + num_traits::AsPrimitive<f32>
+                + num_traits::FromPrimitive,
+        >(
+            src: &[T],
+            layout: &Layout,
+            alpha: &[T],
+            alpha_layout: &Layout,
+            beta: &[T],
+            beta_layout: &Layout,
+            eps: f32,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let alpha = match alpha_layout.contiguous_offsets() {
+                None => candle::bail!("alpha has to be contiguous"),
+                Some((o1, o2)) => &alpha[o1..o2],
+            };
+            let beta = match beta_layout.contiguous_offsets() {
+                None => candle::bail!("beta has to be contiguous"),
+                Some((o1, o2)) => &beta[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let mut sum = 0f32;
+                    let mut sum2 = 0f32;
+                    for v in src {
+                        let v = v.as_();
+                        sum += v;
+                        sum2 += v * v;
+                    }
+                    let mean = sum / dim_m1 as f32;
+                    let var = sum2 / dim_m1 as f32 - mean * mean;
+                    let inv_std = (var + eps).sqrt().recip();
+                    for ((d, s), (alpha, beta)) in
+                        dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
+                    {
+                        let alpha = alpha.as_();
+                        let beta = beta.as_();
+                        let d_ = (s.as_() - mean) * inv_std * alpha + beta;
+                        *d = T::from_f32(d_).unwrap_or_else(T::nan);
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        use CpuStorage as C;
+        match (s1, s2, s3) {
+            (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
+                inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
+            }
+            (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
+            (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
+            _ => candle::bail!("unsupported dtype for layernorm {:?}", s1.dtype()),
+        }
+    }
+
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
s1: &candle::CudaStorage,
|
||||
l1: &Layout,
|
||||
s2: &candle::CudaStorage,
|
||||
l2: &Layout,
|
||||
s3: &candle::CudaStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
||||
struct S {
|
||||
eps: f32,
|
||||
}
|
||||
impl Map3 for S {
|
||||
fn f<T: DeviceRepr + WithDType>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
layout: &Layout,
|
||||
alpha: &CudaSlice<T>,
|
||||
alpha_layout: &Layout,
|
||||
beta: &CudaSlice<T>,
|
||||
beta_layout: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let src = match layout.contiguous_offsets() {
|
||||
None => candle::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
};
|
||||
let alpha = match alpha_layout.contiguous_offsets() {
|
||||
None => candle::bail!("alpha has to be contiguous"),
|
||||
Some((o1, o2)) => alpha.slice(o1..o2),
|
||||
};
|
||||
let beta = match beta_layout.contiguous_offsets() {
|
||||
None => candle::bail!("beta has to be contiguous"),
|
||||
Some((o1, o2)) => beta.slice(o1..o2),
|
||||
};
|
||||
let el = layout.shape().elem_count();
|
||||
let dims = layout.shape().dims();
|
||||
let dim_m1 = dims[dims.len() - 1];
|
||||
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
|
||||
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (n_rows as u32, 1, 1),
|
||||
block_dim: (1024, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
let dev = s1.device();
|
||||
let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
|
||||
let dst = candle::cuda_backend::CudaStorage {
|
||||
slice,
|
||||
device: dev.clone(),
|
||||
};
|
||||
Ok((dst, l1.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
s1: &candle::MetalStorage,
|
||||
l1: &Layout,
|
||||
s2: &candle::MetalStorage,
|
||||
l2: &Layout,
|
||||
s3: &candle::MetalStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
let device = s1.device();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let kernels = device.kernels();
|
||||
let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
|
||||
(DType::F32, DType::F32, DType::F32) => "layernorm_f32",
|
||||
(DType::F16, DType::F16, DType::F16) => "layernorm_f16",
|
||||
(DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
|
||||
(dt1, dt2, dt3) => {
|
||||
candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
|
||||
}
|
||||
};
|
||||
|
||||
if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
|
||||
candle::bail!("Non contiguous layernorm is not implemented");
|
||||
}
|
||||
|
||||
let last_dim = l1.dims()[l1.shape().rank() - 1];
|
||||
let elem_count = l1.shape().elem_count();
|
||||
let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
|
||||
candle_metal_kernels::call_layer_norm(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
kernels,
|
||||
name,
|
||||
elem_count,
|
||||
last_dim,
|
||||
self.eps,
|
||||
s1.buffer(),
|
||||
l1.start_offset() * s1.dtype().size_in_bytes(),
|
||||
s2.buffer(),
|
||||
l2.start_offset() * s2.dtype().size_in_bytes(),
|
||||
s3.buffer(),
|
||||
l3.start_offset() * s3.dtype().size_in_bytes(),
|
||||
&output,
|
||||
)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
|
||||
Ok((newstorage, l1.shape().clone()))
|
||||
}
|
||||
}
|
||||
+
+pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
+    let x_dtype = x.dtype();
+    let internal_dtype = match x_dtype {
+        DType::F16 | DType::BF16 => DType::F32,
+        d => d,
+    };
+    let hidden_size = x.dim(D::Minus1)?;
+    let x = x.to_dtype(internal_dtype)?;
+    let x = {
+        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        x.broadcast_sub(&mean_x)?
+    };
+    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
+    x_normed
+        .to_dtype(x_dtype)?
+        .broadcast_mul(alpha)?
+        .broadcast_add(beta)
+}
+
+pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
+    let hidden_size_xs = xs.dim(D::Minus1)?;
+    let hidden_size_alpha = alpha.dims1()?;
+    let hidden_size_beta = beta.dims1()?;
+    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
+        candle::bail!(
+            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
+            xs.shape(),
+            alpha.shape(),
+            beta.shape()
+        )
+    }
+    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
+}
+
 // https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
 pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
     let (b_size, c, h, w) = xs.dims4()?;
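On the "Add a dedicated test" item in the commit message above, here is a sketch of the kind of check such a test could perform, comparing the specialized op against the slow reference path added in this diff. The function names come from the diff; the module path `candle_nn::ops`, the sizes, and the tolerance are assumptions, not the commit's actual test:

use candle::{Device, Result, Tensor};

fn layer_norm_matches_slow_path() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::randn(0f32, 1f32, (4, 128), &dev)?;
    let alpha = Tensor::randn(0f32, 1f32, 128, &dev)?;
    let beta = Tensor::randn(0f32, 1f32, 128, &dev)?;
    // Fast path: the LayerNorm CustomOp3 (CPU kernel here, CUDA/Metal when those features are enabled).
    let fast = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
    // Reference path: plain tensor ops with an f32 internal dtype.
    let slow = candle_nn::ops::layer_norm_slow(&xs, &alpha, &beta, 1e-5)?;
    // The two paths should agree to within floating-point tolerance.
    let max_diff = (fast - slow)?.abs()?.flatten_all()?.max(0)?.to_scalar::<f32>()?;
    assert!(max_diff < 1e-5, "layer_norm diverges from layer_norm_slow: {max_diff}");
    Ok(())
}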