Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Custom op for RmsNorm (#1890)
* Trying out a custom RmsNorm cuda kernel.
* CPU implementation for rms-norm.
* Cuda wrappers.
* Add some validation.
* Add some testing.
* More testing.
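The kernel computes rms-norm row by row: each element is scaled by the reciprocal root-mean-square of its row, y[i] = alpha[i] * x[i] / sqrt(mean(x^2) + eps), with the sum of squares accumulated in f32. As a reading aid, here is a minimal CPU sketch of that per-row computation; the helper name is illustrative, not the crate's actual CPU path:

// Minimal CPU reference for one row of rms-norm (illustrative only).
#include <cmath>
#include <cstddef>

void rmsnorm_row_ref(const float *x, float *dst, const float *alpha,
                     size_t ncols, float eps) {
    float sum_sq = 0.0f;                       // accumulate in f32
    for (size_t i = 0; i < ncols; ++i) {
        sum_sq += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sum_sq / ncols + eps);
    for (size_t i = 0; i < ncols; ++i) {
        // alpha == nullptr means "no per-channel weight", as in the kernel.
        dst[i] = x[i] * scale * (alpha ? alpha[i] : 1.0f);
    }
}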
@@ -2,6 +2,7 @@
 #include <cmath>
 #include <stdint.h>
 
+#define WARP_SIZE 32
 const int BLOCK_SIZE = 1024;
 
 // TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
@@ -49,6 +50,59 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
     dst[dst_id] = shr[0];
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+// RmsNorm implementation adapted from ggml, accumulation is made using f32.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
+template <typename T>
+__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+    const int block_size = blockDim.x;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = static_cast<float>(x[row*ncols + col]);
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    if (alpha == nullptr) {
+        for (int col = tid; col < ncols; col += block_size) {
+            dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]));
+        }
+    }
+    else {
+        for (int col = tid; col < ncols; col += block_size) {
+            float a = static_cast<float>(alpha[col]);
+            dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]) * a);
+        }
+    }
+}
+
 // Softmax implementation adapted from ggml.
 // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
 template <typename T, typename ACC>
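Note the indexing in rmsnorm: blockDim.x threads cooperate on one row, and blockIdx.x * blockDim.y + threadIdx.y selects the row. In candle the launch itself happens on the Rust side of the crate; the standalone sketch below only illustrates a geometry consistent with that indexing (the wrapper name and shape are assumptions):

// Hypothetical launch shape for rmsnorm_f32 (defined via RMSNORM_OP below):
// one block per row, BLOCK_SIZE threads per block. With 1024 threads the
// block holds exactly 32 warps, so the 32-slot s_sum buffer is fully
// written before the cross-warp reduction reads it.
void launch_rmsnorm_f32(const float *src, float *dst, const float *alpha,
                        int nrows, int ncols, float eps) {
    dim3 block(1024, 1, 1);  // blockDim.y == 1, so row == blockIdx.x
    dim3 grid(nrows, 1, 1);
    rmsnorm_f32<<<grid, block>>>(src, dst, alpha, ncols, eps);
}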
@@ -341,14 +395,23 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
   } \
 
+#define RMSNORM_OP(TYPENAME, FN_NAME) \
+  extern "C" __global__ void FN_NAME( \
+      const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
+      const int n_cols, const float eps) { \
+    rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
+  } \
+
 #if __CUDA_ARCH__ >= 800
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
+RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
 SUM_OP(__nv_bfloat16, sum_bf16)
 FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 SOFTMAX_OP(__half, float, softmax_f16)
+RMSNORM_OP(__half, rmsnorm_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
 #endif
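For concreteness, the instantiation RMSNORM_OP(float, rmsnorm_f32) in the final hunk below expands mechanically to:

// Expansion of RMSNORM_OP(float, rmsnorm_f32):
extern "C" __global__ void rmsnorm_f32(
    const float *src, float *dst, const float *alpha,
    const int n_cols, const float eps) {
  rmsnorm<float>(src, dst, alpha, n_cols, eps);
}

The extern "C" wrapper keeps the symbol unmangled, so the host side can resolve one entry point per dtype by name.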
@@ -358,6 +421,8 @@ SUM_OP(double, sum_f64)
 SUM_OP(uint32_t, sum_u32)
 SOFTMAX_OP(float, float, softmax_f32)
 SOFTMAX_OP(double, double, softmax_f64)
+RMSNORM_OP(float, rmsnorm_f32)
+RMSNORM_OP(double, rmsnorm_f64)
 
 FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
 FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
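A quick way to sanity-check the kernel, in the spirit of the commit's testing steps: feed it a constant row, where the expected output is known in closed form. The harness below is a hypothetical sketch compiled against this file, not the crate's actual test suite:

// Hypothetical smoke test: with every input equal to 0.5, mean(x^2) = 0.25,
// so scale ~= 2 and every output should be ~1.0 (alpha == nullptr path).
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

extern "C" __global__ void rmsnorm_f32(const float *src, float *dst,
                                       const float *alpha, const int n_cols,
                                       const float eps);

int main() {
    const int nrows = 4, ncols = 512;
    std::vector<float> h_x(nrows * ncols, 0.5f), h_out(nrows * ncols, 0.0f);

    float *d_x, *d_out;
    cudaMalloc(&d_x, h_x.size() * sizeof(float));
    cudaMalloc(&d_out, h_out.size() * sizeof(float));
    cudaMemcpy(d_x, h_x.data(), h_x.size() * sizeof(float), cudaMemcpyHostToDevice);

    // One block per row, 1024 threads per block (32 full warps).
    rmsnorm_f32<<<dim3(nrows), dim3(1024)>>>(d_x, d_out, nullptr, ncols, 1e-5f);
    cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);

    printf("out[0] = %f (expect ~1.0)\n", h_out[0]);
    cudaFree(d_x);
    cudaFree(d_out);
    return 0;
}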