mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Improved launch config for layer-norm/rms-norm.
This commit is contained in:
@ -70,10 +70,9 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
||||
// LayerNorm implementation adapted from ggml, accumulation is made using f32.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
|
||||
template <typename T>
|
||||
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
|
||||
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const int block_size, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
const int block_size = blockDim.x;
|
||||
|
||||
float2 mean_var = make_float2(0.f, 0.f);
|
||||
|
||||
@ -134,10 +133,9 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta,
|
||||
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
|
||||
template <typename T>
|
||||
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
|
||||
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const int block_size, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
const int block_size = blockDim.x;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
@ -530,15 +528,15 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
#define RMSNORM_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
|
||||
const int n_cols, const float eps) { \
|
||||
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
|
||||
const int n_cols, const int block_size, const float eps) { \
|
||||
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, block_size, eps); \
|
||||
} \
|
||||
|
||||
#define LAYERNORM_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
|
||||
const TYPENAME *beta, const int n_cols, const float eps) { \
|
||||
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
|
||||
const TYPENAME *beta, const int n_cols, const int block_size, const float eps) { \
|
||||
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, block_size, eps); \
|
||||
} \
|
||||
|
||||
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
|
||||
|
@ -543,15 +543,23 @@ impl candle::CustomOp2 for RmsNorm {
|
||||
let dim_m1 = dims[dims.len() - 1];
|
||||
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
|
||||
|
||||
let block_size = if n_cols < 1024 { 32 } else { 1024 };
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (n_rows as u32, 1, 1),
|
||||
block_dim: (1024, 1, 1),
|
||||
block_dim: (block_size, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (&src, &dst, &alpha, n_cols as i32, self.eps);
|
||||
let params = (
|
||||
&src,
|
||||
&dst,
|
||||
&alpha,
|
||||
n_cols as i32,
|
||||
block_size as i32,
|
||||
self.eps,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(dst)
|
||||
@ -776,15 +784,24 @@ impl candle::CustomOp3 for LayerNorm {
|
||||
let dim_m1 = dims[dims.len() - 1];
|
||||
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
|
||||
|
||||
let block_size = if n_cols < 1024 { 32 } else { 1024 };
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (n_rows as u32, 1, 1),
|
||||
block_dim: (1024, 1, 1),
|
||||
block_dim: (block_size, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
|
||||
let params = (
|
||||
&src,
|
||||
&dst,
|
||||
&alpha,
|
||||
&beta,
|
||||
n_cols as i32,
|
||||
block_size as i32,
|
||||
self.eps,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(dst)
|
||||
|
Reference in New Issue
Block a user