Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
@@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
    dst[dst_id] = shr[0];
}

static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
    }
    return a;
}

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
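The XOR-mask loop above is a standard warp-level butterfly reduction: at each step, every lane adds the value held by the lane whose index differs in exactly one bit (mask 16, then 8, 4, 2, 1), so after log2(32) = 5 steps all 32 lanes hold the full sum. As a 4-lane illustration with masks 2 and 1: step one leaves x0+x2, x1+x3, x2+x0, x3+x1 in lanes 0..3; step two pairs adjacent lanes, leaving x0+x1+x2+x3 in every lane. The float2 overload does the same for two running sums at once, which the layernorm kernel below uses for sum(x) and sum(x*x).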
@@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
    return x;
}

// LayerNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
template <typename T>
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;
    const int block_size = blockDim.x;

    float2 mean_var = make_float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }
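    // Single pass over the row: mean_var.x accumulates sum(x) and mean_var.y
    // accumulates sum(x*x), so both moments come from one read of the data
    // and the variance can be recovered as E[x^2] - E[x]^2 below.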

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var);
    if (block_size > WARP_SIZE) {
        __shared__ float2 s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        mean_var = s_sum[lane_id];
        mean_var = warp_reduce_sum(mean_var);
    }

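    // E[x] and E[x^2] over the row; Var(x) = E[x^2] - E[x]^2. eps keeps the
    // rsqrtf argument strictly positive for near-constant rows. Both moments
    // were accumulated in f32 even when T is f16/bf16.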
    const float mean = mean_var.x / ncols;
    const float var = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);
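    // The alpha/beta null checks are hoisted out of the per-element loops:
    // one of four specialized loops is selected once per row instead of
    // branching on every element.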

    if (alpha == nullptr && beta == nullptr) {
        for (int col = tid; col < ncols; col += block_size) {
            float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
            dst[row*ncols + col] = static_cast<T>(lhs);
        }
    }
    else if (alpha == nullptr && beta != nullptr) {
        for (int col = tid; col < ncols; col += block_size) {
            float b = static_cast<float>(beta[col]);
            float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
            dst[row*ncols + col] = static_cast<T>(lhs + b);
        }
    }
    else if (alpha != nullptr && beta == nullptr) {
        for (int col = tid; col < ncols; col += block_size) {
            float a = static_cast<float>(alpha[col]);
            float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
            dst[row*ncols + col] = static_cast<T>(lhs * a);
        }
    }
    else {
        for (int col = tid; col < ncols; col += block_size) {
            float a = static_cast<float>(alpha[col]);
            float b = static_cast<float>(beta[col]);
            float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
            dst[row*ncols + col] = static_cast<T>(lhs * a + b);
        }
    }
}

// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
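In formula form, the kernel above computes, per row of N = ncols elements (a summary of the code, not part of the diff):

\[
\mu = \frac{1}{N}\sum_i x_i,\qquad
\sigma^2 = \frac{1}{N}\sum_i x_i^2 - \mu^2,\qquad
y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}}\,\alpha_i + \beta_i,
\]

with \(\alpha_i = 1\) and/or \(\beta_i = 0\) when the corresponding pointer is null.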
@@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
        rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
    } \

#define LAYERNORM_OP(TYPENAME, FN_NAME) \
    extern "C" __global__ void FN_NAME( \
        const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
        const TYPENAME *beta, const int n_cols, const float eps) { \
        layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
    } \

#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
    extern "C" __global__ void FN_NAME_I( \
        const TYPENAME *src, \
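For illustration, the float instantiation that appears further down, LAYERNORM_OP(float, layernorm_f32), expands to (whitespace aside):

extern "C" __global__ void layernorm_f32(
    const float *src, float *dst, const float *alpha,
    const float *beta, const int n_cols, const float eps) {
    layernorm<float>(src, dst, alpha, beta, n_cols, eps);
}

That is, each macro instantiation is just an extern "C" entry point with an unmangled, dtype-suffixed name that forwards to the templated device function.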
@@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
@@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
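A note on the guards: the __nv_bfloat16 instantiations are compiled only for compute capability 8.0 (Ampere) and newer, and the __half ones only for 5.3 and newer, since bf16 and fp16 arithmetic are unavailable on older architectures. The f32/f64 instantiations in the next hunk are unconditional.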
@@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
LAYERNORM_OP(float, layernorm_f32)
LAYERNORM_OP(double, layernorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
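For context, here is a minimal host-side sketch of how the layernorm_f32 entry point might be launched. It is not part of the diff: it assumes the kernels above (and the WARP_SIZE constant defined elsewhere in reduce.cu) are compiled into the same translation unit, and the grid/block mapping (one row per block, blockDim.y == 1 so row == blockIdx.x) is an assumption read off the kernel's indexing, not necessarily candle's actual launch configuration.

// Hypothetical test harness, not from the candle sources.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    const int nrows = 4, ncols = 512;
    const float eps = 1e-5f;
    float *src, *dst, *alpha, *beta;
    cudaMallocManaged(&src, nrows * ncols * sizeof(float));
    cudaMallocManaged(&dst, nrows * ncols * sizeof(float));
    cudaMallocManaged(&alpha, ncols * sizeof(float));
    cudaMallocManaged(&beta, ncols * sizeof(float));
    for (int i = 0; i < nrows * ncols; ++i) src[i] = (float)(i % 13);
    for (int i = 0; i < ncols; ++i) { alpha[i] = 1.f; beta[i] = 0.f; }

    // One row per block (blockDim.y == 1 => row == blockIdx.x). The kernel's
    // inter-warp reduction reads all 32 s_sum slots, so blockDim.x should be
    // either <= 32 (warp-only path) or exactly 1024 (32 full warps).
    dim3 grid(nrows, 1);
    dim3 block(32, 1);
    layernorm_f32<<<grid, block>>>(src, dst, alpha, beta, ncols, eps);
    cudaDeviceSynchronize();
    printf("dst[0] = %f, dst[ncols-1] = %f\n", dst[0], dst[ncols - 1]);

    cudaFree(src); cudaFree(dst); cudaFree(alpha); cudaFree(beta);
    return 0;
}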