diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index aaac24a1..079c3708 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -70,10 +70,9 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) { // LayerNorm implementation adapted from ggml, accumulation is made using f32. // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477 template -__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) { +__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const int block_size, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - const int block_size = blockDim.x; float2 mean_var = make_float2(0.f, 0.f); @@ -134,10 +133,9 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, // RmsNorm implementation adapted from ggml, accumulation is made using f32. // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523 template -__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) { +__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const int block_size, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - const int block_size = blockDim.x; float tmp = 0.0f; // partial sum for thread in warp @@ -530,15 +528,15 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, #define RMSNORM_OP(TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \ - const int n_cols, const float eps) { \ - rmsnorm(src, dst, alpha, n_cols, eps); \ + const int n_cols, const int block_size, const float eps) { \ + rmsnorm(src, dst, alpha, n_cols, block_size, eps); \ } \ #define LAYERNORM_OP(TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \ - const TYPENAME *beta, const int n_cols, const float eps) { \ - layernorm(src, dst, alpha, beta, n_cols, eps); \ + const TYPENAME *beta, const int n_cols, const int block_size, const float eps) { \ + layernorm(src, dst, alpha, beta, n_cols, block_size, eps); \ } \ #define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \ diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 9a360c47..8a3c19fe 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -543,15 +543,23 @@ impl candle::CustomOp2 for RmsNorm { let dim_m1 = dims[dims.len() - 1]; let (n_rows, n_cols) = (el / dim_m1, dim_m1); + let block_size = if n_cols < 1024 { 32 } else { 1024 }; let cfg = LaunchConfig { grid_dim: (n_rows as u32, 1, 1), - block_dim: (1024, 1, 1), + block_dim: (block_size, 1, 1), shared_mem_bytes: 0, }; let func = dev.get_or_load_func(&kernel_name::("rmsnorm"), kernels::REDUCE)?; // SAFETY: Set later by running the kernel. let dst = unsafe { dev.alloc::(el) }.w()?; - let params = (&src, &dst, &alpha, n_cols as i32, self.eps); + let params = ( + &src, + &dst, + &alpha, + n_cols as i32, + block_size as i32, + self.eps, + ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(dst) @@ -776,15 +784,24 @@ impl candle::CustomOp3 for LayerNorm { let dim_m1 = dims[dims.len() - 1]; let (n_rows, n_cols) = (el / dim_m1, dim_m1); + let block_size = if n_cols < 1024 { 32 } else { 1024 }; let cfg = LaunchConfig { grid_dim: (n_rows as u32, 1, 1), - block_dim: (1024, 1, 1), + block_dim: (block_size, 1, 1), shared_mem_bytes: 0, }; let func = dev.get_or_load_func(&kernel_name::("layernorm"), kernels::REDUCE)?; // SAFETY: Set later by running the kernel. let dst = unsafe { dev.alloc::(el) }.w()?; - let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps); + let params = ( + &src, + &dst, + &alpha, + &beta, + n_cols as i32, + block_size as i32, + self.eps, + ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(dst)