Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Custom op for RmsNorm (#1890)
* Trying out a custom RmsNorm cuda kernel.
* CPU implementation for rms-norm.
* Cuda wrappers.
* Add some validation.
* Add some testing.
* More testing.
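For reference, the operation implemented by both the CUDA kernel and the CPU path below is the usual RMS normalization over the last dimension (the formula is implied by the kernel's `scale = rsqrtf(mean + eps)` computation rather than stated in the commit):

    \mathrm{rmsnorm}(x)_i = \alpha_i \cdot \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}}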
@@ -2,6 +2,7 @@
 #include <cmath>
 #include <stdint.h>
 
+#define WARP_SIZE 32
 const int BLOCK_SIZE = 1024;
 
 // TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
@@ -49,6 +50,59 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
   dst[dst_id] = shr[0];
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+  }
+  return x;
+}
+
+// RmsNorm implementation adapted from ggml, accumulation is made using f32.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
+template <typename T>
+__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
+  const int row = blockIdx.x*blockDim.y + threadIdx.y;
+  const int tid = threadIdx.x;
+  const int block_size = blockDim.x;
+
+  float tmp = 0.0f; // partial sum for thread in warp
+
+  for (int col = tid; col < ncols; col += block_size) {
+    const float xi = static_cast<float>(x[row*ncols + col]);
+    tmp += xi * xi;
+  }
+
+  // sum up partial sums
+  tmp = warp_reduce_sum(tmp);
+  if (block_size > WARP_SIZE) {
+    __shared__ float s_sum[32];
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    if (lane_id == 0) {
+      s_sum[warp_id] = tmp;
+    }
+    __syncthreads();
+    tmp = s_sum[lane_id];
+    tmp = warp_reduce_sum(tmp);
+  }
+
+  const float mean = tmp / ncols;
+  const float scale = rsqrtf(mean + eps);
+
+  if (alpha == nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]));
+    }
+  }
+  else {
+    for (int col = tid; col < ncols; col += block_size) {
+      float a = static_cast<float>(alpha[col]);
+      dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]) * a);
+    }
+  }
+}
+
 // Softmax implementation adapted from ggml.
 // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
 template <typename T, typename ACC>
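A note on the reduction above: `__shfl_xor_sync` exchanges values between lanes whose indices differ by `mask`, so the loop over masks 16, 8, 4, 2, 1 is a butterfly that leaves the full 32-lane sum in every lane after five steps. A minimal CPU sketch of that pattern, written in Rust rather than CUDA (this emulation is not part of the commit):

    // Emulate the XOR-butterfly of warp_reduce_sum on a 32-element array.
    // After each step, lane i holds prev[i] + prev[i ^ mask]; after all five
    // steps every lane holds the sum of all 32 original values.
    fn warp_reduce_sum_emulated(lanes: &mut [f32; 32]) {
        let mut mask = 16usize;
        while mask > 0 {
            let prev = *lanes;
            for i in 0..32 {
                lanes[i] = prev[i] + prev[i ^ mask];
            }
            mask >>= 1;
        }
    }

    fn main() {
        let mut lanes: [f32; 32] = core::array::from_fn(|i| i as f32);
        warp_reduce_sum_emulated(&mut lanes);
        // 0 + 1 + ... + 31 = 496, now present in every lane.
        assert!(lanes.iter().all(|&x| (x - 496.0).abs() < 1e-3));
    }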
@@ -341,14 +395,23 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
   } \
 
+#define RMSNORM_OP(TYPENAME, FN_NAME) \
+  extern "C" __global__ void FN_NAME( \
+      const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
+      const int n_cols, const float eps) { \
+    rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
+  } \
+
 #if __CUDA_ARCH__ >= 800
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
+RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
 SUM_OP(__nv_bfloat16, sum_bf16)
 FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 SOFTMAX_OP(__half, float, softmax_f16)
+RMSNORM_OP(__half, rmsnorm_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
 #endif
@@ -358,6 +421,8 @@ SUM_OP(double, sum_f64)
 SUM_OP(uint32_t, sum_u32)
 SOFTMAX_OP(float, float, softmax_f32)
 SOFTMAX_OP(double, double, softmax_f64)
+RMSNORM_OP(float, rmsnorm_f32)
+RMSNORM_OP(double, rmsnorm_f64)
 
 FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
 FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
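The `RMSNORM_OP` registrations above stamp out one `extern "C"` entry point per dtype (`rmsnorm_f32`, `rmsnorm_f16`, and so on), and `kernel_name::<T>("rmsnorm")` in the Rust wrapper below resolves the matching symbol at launch time. The convention is presumably just a dtype suffix on the stem; a hedged sketch of that assumption (the real `kernel_name` helper lives in candle's cuda backend):

    // Sketch of the assumed naming convention, not candle's actual helper.
    fn kernel_name(stem: &str, dtype_suffix: &str) -> String {
        // e.g. "rmsnorm" + "bf16" -> "rmsnorm_bf16"
        format!("{stem}_{dtype_suffix}")
    }

    fn main() {
        assert_eq!(kernel_name("rmsnorm", "f32"), "rmsnorm_f32");
    }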
@@ -1,4 +1,4 @@
-use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
 use rayon::prelude::*;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -180,11 +180,10 @@ impl candle::CustomOp1 for SoftmaxLastDim {
             block_dim: (1, 32, 1),
             shared_mem_bytes: 0,
         };
-        let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
         // SAFETY: Set later by running the kernel.
         let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-        let params = (src, &dst, n_cols as i32);
+        let params = (&src, &dst, n_cols as i32);
        // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
         Ok(dst)
@@ -207,7 +206,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         storage: &candle::MetalStorage,
         layout: &Layout,
     ) -> Result<(candle::MetalStorage, Shape)> {
-        use candle::{backend::BackendStorage, DType};
+        use candle::backend::BackendStorage;
         let device = storage.device();
         let command_buffer = device.command_buffer()?;
         let kernels = device.kernels();
@@ -248,6 +247,170 @@ pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
     xs.apply_op1_no_bwd(&SoftmaxLastDim)
 }
 
+#[derive(Debug, Clone)]
+struct RmsNorm {
+    eps: f32,
+}
+
+impl candle::CustomOp2 for RmsNorm {
+    fn name(&self) -> &'static str {
+        "rms-norm"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        use candle::backend::BackendStorage;
+
+        let eps = self.eps;
+        fn inner<
+            T: candle::WithDType
+                + num_traits::Float
+                + num_traits::AsPrimitive<f32>
+                + num_traits::FromPrimitive,
+        >(
+            src: &[T],
+            layout: &Layout,
+            alpha: &[T],
+            alpha_layout: &Layout,
+            eps: f32,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let alpha = match alpha_layout.contiguous_offsets() {
+                None => candle::bail!("alpha has to be contiguous"),
+                Some((o1, o2)) => &alpha[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let sum2 = src
+                        .iter()
+                        .map(|&v| {
+                            let v = v.as_();
+                            v * v
+                        })
+                        .sum::<f32>();
+                    let m = (sum2 / dim_m1 as f32 + eps).sqrt();
+                    let m = T::from_f32(m).unwrap_or_else(T::nan);
+                    for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
+                        *d = *s / m * *alpha
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        use CpuStorage as C;
+        match (s1, s2) {
+            (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
+            (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
+            (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
+            _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
+        }
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s1: &candle::CudaStorage,
+        l1: &Layout,
+        s2: &candle::CudaStorage,
+        l2: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+        };
+        use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        struct S {
+            eps: f32,
+        }
+        impl Map2 for S {
+            fn f<T: DeviceRepr + WithDType>(
+                &self,
+                src: &CudaSlice<T>,
+                layout: &Layout,
+                alpha: &CudaSlice<T>,
+                alpha_layout: &Layout,
+                dev: &CudaDevice,
+            ) -> Result<CudaSlice<T>> {
+                let src = match layout.contiguous_offsets() {
+                    None => candle::bail!("input has to be contiguous"),
+                    Some((o1, o2)) => src.slice(o1..o2),
+                };
+                let alpha = match alpha_layout.contiguous_offsets() {
+                    None => candle::bail!("alpha has to be contiguous"),
+                    Some((o1, o2)) => alpha.slice(o1..o2),
+                };
+                let el = layout.shape().elem_count();
+                let dims = layout.shape().dims();
+                let dim_m1 = dims[dims.len() - 1];
+                let (n_rows, n_cols) = (el / dim_m1, dim_m1);
+
+                let cfg = LaunchConfig {
+                    grid_dim: (n_rows as u32, 1, 1),
+                    block_dim: (1024, 1, 1),
+                    shared_mem_bytes: 0,
+                };
+                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
+                // SAFETY: Set later by running the kernel.
+                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+                let params = (&src, &dst, &alpha, n_cols as i32, self.eps);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+                Ok(dst)
+            }
+        }
+
+        use candle::backend::BackendStorage;
+        let dev = s1.device();
+        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
+        let dst = candle::cuda_backend::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, l1.shape().clone()))
+    }
+}
+
+pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
+    let x_dtype = x.dtype();
+    let internal_dtype = match x_dtype {
+        DType::F16 | DType::BF16 => DType::F32,
+        d => d,
+    };
+    let hidden_size = x.dim(candle::D::Minus1)?;
+    let x = x.to_dtype(internal_dtype)?;
+    let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
+    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
+    x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
+}
+
+pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
+    let hidden_size_xs = xs.dim(candle::D::Minus1)?;
+    let hidden_size_alpha = alpha.dims1()?;
+    if hidden_size_xs != hidden_size_alpha {
+        candle::bail!(
+            "shape mismatch in rms-norm {:?} {:?}",
+            xs.shape(),
+            alpha.shape()
+        )
+    }
+    xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
+}
+
 // https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
 pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
     let (b_size, c, h, w) = xs.dims4()?;
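Taken together, `rms_norm` dispatches to the custom op (the rayon-parallel loop on CPU, the new kernel on CUDA) while `rms_norm_slow` is a reference path built from ordinary tensor ops. A small usage sketch, not part of the commit, assuming the `candle` and `candle-nn` crates from this repository as dependencies:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let xs = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &device)?;
        let alpha = Tensor::new(&[1f32, 1., 1.], &device)?;
        // Fast path via the custom op.
        let fast = candle_nn::ops::rms_norm(&xs, &alpha, 1e-5)?;
        // Reference path built from ordinary tensor ops.
        let slow = candle_nn::ops::rms_norm_slow(&xs, &alpha, 1e-5)?;
        // The two paths should agree to within rounding.
        let diff = (fast - slow)?.abs()?.sum_all()?.to_vec0::<f32>()?;
        assert!(diff < 1e-5);
        Ok(())
    }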
@@ -4,11 +4,9 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use candle::{test_utils::to_vec3_round, Device, Result, Tensor};
+use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
 
-#[test]
-fn softmax() -> Result<()> {
-    let device = &Device::Cpu;
+fn softmax(device: &Device) -> Result<()> {
     let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
     let tensor = Tensor::new(data, device)?;
     let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?;
@@ -54,6 +52,31 @@ fn softmax() -> Result<()> {
     Ok(())
 }
 
+fn rms_norm(device: &Device) -> Result<()> {
+    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+    let tensor = Tensor::new(data, device)?;
+    let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
+    let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
+    assert_eq!(
+        to_vec3_round(&t, 4)?,
+        &[
+            [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]],
+            [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
+        ]
+    );
+    let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?;
+    assert_eq!(
+        to_vec3_round(&t2, 4)?,
+        &[
+            [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]],
+            [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
+        ]
+    );
+    let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert!(diff < 1e-5);
+    Ok(())
+}
+
 #[test]
 fn softmax_numerical_stability() -> Result<()> {
     let dev = &Device::Cpu;
@@ -62,3 +85,6 @@ fn softmax_numerical_stability() -> Result<()> {
     assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]);
     Ok(())
 }
+
+test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
+test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
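`test_device!` is candle's helper for running one test body against each backend. The exact expansion lives in candle's test utilities, but it is roughly of this shape (a hedged sketch under that assumption, not the actual macro output):

    // Approximate expansion of test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal):
    #[test]
    fn softmax_cpu() -> candle::Result<()> {
        softmax(&candle::Device::Cpu)
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn softmax_gpu() -> candle::Result<()> {
        softmax(&candle::Device::new_cuda(0)?)
    }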