diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index fca6865e..19fb213a 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -2,6 +2,7 @@
 #include <cmath>
 #include <stdint.h>
 
+#define WARP_SIZE 32
 const int BLOCK_SIZE = 1024;
 
 // TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
@@ -49,6 +50,59 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
   dst[dst_id] = shr[0];
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+  }
+  return x;
+}
+
+// RmsNorm implementation adapted from ggml, accumulation is made using f32.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
+template <typename T>
+__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
+  const int row = blockIdx.x*blockDim.y + threadIdx.y;
+  const int tid = threadIdx.x;
+  const int block_size = blockDim.x;
+
+  float tmp = 0.0f; // partial sum for thread in warp
+
+  for (int col = tid; col < ncols; col += block_size) {
+    const float xi = static_cast<float>(x[row*ncols + col]);
+    tmp += xi * xi;
+  }
+
+  // sum up partial sums
+  tmp = warp_reduce_sum(tmp);
+  if (block_size > WARP_SIZE) {
+    __shared__ float s_sum[32];
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    if (lane_id == 0) {
+      s_sum[warp_id] = tmp;
+    }
+    __syncthreads();
+    tmp = s_sum[lane_id];
+    tmp = warp_reduce_sum(tmp);
+  }
+
+  const float mean = tmp / ncols;
+  const float scale = rsqrtf(mean + eps);
+
+  if (alpha == nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]));
+    }
+  }
+  else {
+    for (int col = tid; col < ncols; col += block_size) {
+      float a = static_cast<float>(alpha[col]);
+      dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]) * a);
+    }
+  }
+}
+
 // Softmax implementation adapted from ggml.
 // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
 template <typename T, typename ACC>
@@ -341,14 +395,23 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
   } \
 
+#define RMSNORM_OP(TYPENAME, FN_NAME) \
+  extern "C" __global__ void FN_NAME( \
+      const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
+      const int n_cols, const float eps) { \
+    rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
+  } \
+
 #if __CUDA_ARCH__ >= 800
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
+RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
 SUM_OP(__nv_bfloat16, sum_bf16)
 FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 SOFTMAX_OP(__half, float, softmax_f16)
+RMSNORM_OP(__half, rmsnorm_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
 #endif
@@ -358,6 +421,8 @@ SUM_OP(double, sum_f64)
 SUM_OP(uint32_t, sum_u32)
 SOFTMAX_OP(float, float, softmax_f32)
 SOFTMAX_OP(double, double, softmax_f64)
+RMSNORM_OP(float, rmsnorm_f32)
+RMSNORM_OP(double, rmsnorm_f64)
 FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
 FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 88d1b3d6..d725bdc2 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,4 +1,4 @@
-use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
 use rayon::prelude::*;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -180,11 +180,10 @@ impl candle::CustomOp1 for SoftmaxLastDim {
                     block_dim: (1, 32, 1),
                     shared_mem_bytes: 0,
                 };
-                let src = &src.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
                 // SAFETY: Set later by running the kernel.
                 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-                let params = (src, &dst, n_cols as i32);
+                let params = (&src, &dst, n_cols as i32);
                 // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }.w()?;
                 Ok(dst)
             }
@@ -207,7 +206,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         storage: &candle::MetalStorage,
         layout: &Layout,
     ) -> Result<(candle::MetalStorage, Shape)> {
-        use candle::{backend::BackendStorage, DType};
+        use candle::backend::BackendStorage;
         let device = storage.device();
         let command_buffer = device.command_buffer()?;
         let kernels = device.kernels();
@@ -248,6 +247,170 @@ pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
     xs.apply_op1_no_bwd(&SoftmaxLastDim)
 }
 
+#[derive(Debug, Clone)]
+struct RmsNorm {
+    eps: f32,
+}
+
+impl candle::CustomOp2 for RmsNorm {
+    fn name(&self) -> &'static str {
+        "rms-norm"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        use candle::backend::BackendStorage;
+
+        let eps = self.eps;
+        fn inner<
+            T: candle::WithDType
+                + num_traits::Float
+                + num_traits::AsPrimitive<f32>
+                + num_traits::FromPrimitive,
+        >(
+            src: &[T],
+            layout: &Layout,
+            alpha: &[T],
+            alpha_layout: &Layout,
+            eps: f32,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let alpha = match alpha_layout.contiguous_offsets() {
+                None => candle::bail!("alpha has to be contiguous"),
+                Some((o1, o2)) => &alpha[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let sum2 = src
+                        .iter()
+                        .map(|&v| {
+                            let v = v.as_();
+                            v * v
+                        })
+                        .sum::<f32>();
+                    let m = (sum2 / dim_m1 as f32 + eps).sqrt();
+                    let m = T::from_f32(m).unwrap_or_else(T::nan);
+                    for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
+                        *d = *s / m * *alpha
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        use CpuStorage as C;
+        match (s1, s2) {
+            (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
+            (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
+            (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
+            _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
+        }
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s1: &candle::CudaStorage,
+        l1: &Layout,
+        s2: &candle::CudaStorage,
+        l2: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+        };
+        use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        struct S {
+            eps: f32,
+        }
+        impl Map2 for S {
+            fn f<T: DeviceRepr + WithDType>(
+                &self,
+                src: &CudaSlice<T>,
+                layout: &Layout,
+                alpha: &CudaSlice<T>,
+                alpha_layout: &Layout,
+                dev: &CudaDevice,
+            ) -> Result<CudaSlice<T>> {
+                let src = match layout.contiguous_offsets() {
+                    None => candle::bail!("input has to be contiguous"),
+                    Some((o1, o2)) => src.slice(o1..o2),
+                };
+                let alpha = match alpha_layout.contiguous_offsets() {
+                    None => candle::bail!("alpha has to be contiguous"),
+                    Some((o1, o2)) => alpha.slice(o1..o2),
+                };
+                let el = layout.shape().elem_count();
+                let dims = layout.shape().dims();
+                let dim_m1 = dims[dims.len() - 1];
+                let (n_rows, n_cols) = (el / dim_m1, dim_m1);
+
+                let cfg = LaunchConfig {
+                    grid_dim: (n_rows as u32, 1, 1),
+                    block_dim: (1024, 1, 1),
+                    shared_mem_bytes: 0,
+                };
+                let func =
+                    dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
+                // SAFETY: Set later by running the kernel.
+                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+                let params = (&src, &dst, &alpha, n_cols as i32, self.eps);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+                Ok(dst)
+            }
+        }
+
+        use candle::backend::BackendStorage;
+        let dev = s1.device();
+        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
+        let dst = candle::cuda_backend::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, l1.shape().clone()))
+    }
+}
+
+pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
+    let x_dtype = x.dtype();
+    let internal_dtype = match x_dtype {
+        DType::F16 | DType::BF16 => DType::F32,
+        d => d,
+    };
+    let hidden_size = x.dim(candle::D::Minus1)?;
+    let x = x.to_dtype(internal_dtype)?;
+    let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
+    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
+    x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
+}
+
+pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
+    let hidden_size_xs = xs.dim(candle::D::Minus1)?;
+    let hidden_size_alpha = alpha.dims1()?;
+    if hidden_size_xs != hidden_size_alpha {
+        candle::bail!(
+            "shape mismatch in rms-norm {:?} {:?}",
+            xs.shape(),
+            alpha.shape()
+        )
+    }
+    xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
+}
+
 // https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
 pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
     let (b_size, c, h, w) = xs.dims4()?;
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 5ca01b37..c1e3031f 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -4,11 +4,9 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use candle::{test_utils::to_vec3_round, Device, Result, Tensor};
+use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
 
-#[test]
-fn softmax() -> Result<()> {
-    let device = &Device::Cpu;
+fn softmax(device: &Device) -> Result<()> {
     let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
     let tensor = Tensor::new(data, device)?;
     let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?;
@@ -54,6 +52,31 @@ fn softmax() -> Result<()> {
     Ok(())
 }
 
+fn rms_norm(device: &Device) -> Result<()> {
+    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+    let tensor = Tensor::new(data, device)?;
+    let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
+    let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
+    assert_eq!(
+        to_vec3_round(&t, 4)?,
+        &[
+            [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]],
+            [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
+        ]
+    );
+    let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?;
+    assert_eq!(
+        to_vec3_round(&t2, 4)?,
+        &[
+            [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]],
+            [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
+        ]
+    );
+    let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert!(diff < 1e-5);
+    Ok(())
+}
+
 #[test]
 fn softmax_numerical_stability() -> Result<()> {
     let dev = &Device::Cpu;
@@ -62,3 +85,6 @@ fn softmax_numerical_stability() -> Result<()> {
     assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]);
     Ok(())
 }
+
+test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
+test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
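Not part of the patch: a minimal usage sketch of the new op, assuming a binary that depends on the patched `candle` and `candle-nn` crates. The `main` function, shapes, and tolerance below are illustrative only; it mirrors the new test by comparing the fused `candle_nn::ops::rms_norm` (CPU or CUDA custom op above) against the composed-ops `rms_norm_slow`, on CUDA when available and on CPU otherwise.

use candle::{DType, Device, Result, Tensor};

// Hypothetical example binary, not included in the PR.
fn main() -> Result<()> {
    // Pick the CUDA backend if it was compiled in and a device is present,
    // otherwise fall back to the CPU implementation of the custom op.
    let device = Device::cuda_if_available(0)?;
    let xs = Tensor::randn(0f32, 1f32, (4, 1024), &device)?;
    let alpha = Tensor::ones(1024, DType::F32, &device)?;
    // Fused custom op vs. the slow reference path built from tensor ops.
    let fast = candle_nn::ops::rms_norm(&xs, &alpha, 1e-5)?;
    let slow = candle_nn::ops::rms_norm_slow(&xs, &alpha, 1e-5)?;
    let diff = (fast - slow)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    println!("sum of |rms_norm - rms_norm_slow| = {diff}");
    Ok(())
}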