diff --git a/candle-kernels/README.md b/candle-kernels/README.md index a527dde6..1043f31f 100644 --- a/candle-kernels/README.md +++ b/candle-kernels/README.md @@ -2,7 +2,3 @@ This crate contains CUDA kernels used from candle. Some of these implementations come from the [dfdx crate](https://github.com/coreylowman/dfdx). - -The `ln*` files come from the [flash-attention -repo](https://github.com/Dao-AILab/flash-attention) and have been edited so as -to compile without including the PyTorch codebase. diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs index 58518412..3c8e96a9 100644 --- a/candle-kernels/build.rs +++ b/candle-kernels/build.rs @@ -184,7 +184,6 @@ mod cuda { let mut command = std::process::Command::new("nvcc"); command.arg(format!("--gpu-architecture=sm_{compute_cap}")) .arg("--ptx") - .arg("--expt-relaxed-constexpr") .args(["--default-stream", "per-thread"]) .args(["--output-directory", &out_dir]) // Flash attention only diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs index 848daee5..b9d12b7b 100644 --- a/candle-kernels/src/lib.rs +++ b/candle-kernels/src/lib.rs @@ -4,7 +4,6 @@ pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx")); pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx")); pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx")); pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); -pub const LN_FWD_256: &str = include_str!(concat!(env!("OUT_DIR"), "/ln_fwd_256.ptx")); pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx")); pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx")); diff --git a/candle-kernels/src/ln.h b/candle-kernels/src/ln.h deleted file mode 100644 index 3acf18ec..00000000 --- a/candle-kernels/src/ln.h +++ /dev/null @@ -1,274 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams{ - - size_t elts_per_thread; - size_t workspace_bytes; - size_t barrier_size; - - cudaDeviceProp * props; - - cudaStream_t stream; - - Params params; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ParamsBase { - ParamsBase() - : ctas_per_col(0) - , rows(0) - , cols(0) - , x(nullptr) - , mu(nullptr) - , rs(nullptr) - , gamma(nullptr) - , gamma1(nullptr) - , rowscale(nullptr) - , colscale(nullptr) - , dropout_keep_p(1.f) - , dropout_scale(1.f) - , is_rms_norm(false) - , workspace(nullptr) - , barrier(nullptr) - { - } - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. - int ctas_per_col; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x0; - void *x1; - void *residual; - void *x; - void *dmask; - void *dmask1; - void *mu; - void *rs; - void *gamma; - void *gamma1; - void *rowscale; - void *colscale; - void *x0_subset; - void *z_subset; - - float inverse_cols; - - float dropout_keep_p; - float dropout_scale; - float rowscale_const; - - bool is_rms_norm; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. - int *barrier; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FwdParams : public ParamsBase { - FwdParams() - : ParamsBase() - , z(nullptr) - , z1(nullptr) - , beta(nullptr) - , beta1(nullptr) - , epsilon(0.f) - { - } - - // Output of LN FWD. - void *z; - void *z1; - void *beta; - void *beta1; - float epsilon; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct BwdParams : public ParamsBase { - BwdParams() - : ParamsBase() - , dz(nullptr) - , dz1(nullptr) - , dx(nullptr) - , dbeta_part(nullptr) - , dgamma_part(nullptr) - , dbeta1_part(nullptr) - , dgamma1_part(nullptr) - , dcolscale_part(nullptr) - , dx0(nullptr) - , dx1(nullptr) - , dresidual(nullptr) - , dbeta(nullptr) - , dgamma(nullptr) - , dbeta1(nullptr) - , dgamma1(nullptr) - , dcolscale(nullptr) - { - } - - // Input: gradient wrt. LN FWD output. - void *dz; - void *dz1; - // Input: gradient wrt residual. - void *dx; - - // Workspace for Wgrad pre-reduction. - void *dbeta_part; - void *dgamma_part; - void *dbeta1_part; - void *dgamma1_part; - void *dcolscale_part; - - // Output: Dgrad. - void *dx0; - void *dx1; - void *dresidual; - // Output: Wgrad. - void *dbeta; - void *dgamma; - void *dbeta1; - void *dgamma1; - void *dcolscale; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function&, const bool)>; -using BwdFunction = std::function&, const bool)>; -using FunctionKey = uint64_t; -using FwdRegistry = std::unordered_map; -using BwdRegistry = std::unordered_map; - -extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS; -extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using fp32 = float; -using fp16 = half; -using bf16 = nv_bfloat16; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeId{}; - -template<> -struct TypeId{ - constexpr static uint32_t Value = 0; -}; - -template<> -struct TypeId{ - constexpr static uint32_t Value = 1; -}; - -template<> -struct TypeId{ - constexpr static uint32_t Value = 2; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Type2Key{ - constexpr static uint32_t Value = TypeId::Value << S; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct WeightType2Key : public Type2Key{}; - -template -struct InputType2Key : public Type2Key{}; - -template -struct ResidualType2Key : public Type2Key{}; - -template -struct OutputType2Key : public Type2Key{}; - -template -struct ComputeType2Key : public Type2Key{}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Types2Key{ - constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | ResidualType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size){ - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdRegistrar{ - FwdRegistrar(FwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - FWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdRegistrar{ - BwdRegistrar(BwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - BWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdParallelRegistrar{ - FwdParallelRegistrar(FwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - PARALLEL_FWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdParallelRegistrar{ - BwdParallelRegistrar(BwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - PARALLEL_BWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/candle-kernels/src/ln_fwd_256.cu b/candle-kernels/src/ln_fwd_256.cu deleted file mode 100644 index f3a541c6..00000000 --- a/candle-kernels/src/ln_fwd_256.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/candle-kernels/src/ln_fwd_kernels.cuh b/candle-kernels/src/ln_fwd_kernels.cuh deleted file mode 100644 index faa64d05..00000000 --- a/candle-kernels/src/ln_fwd_kernels.cuh +++ /dev/null @@ -1,257 +0,0 @@ -#pragma once - -#include - -#include "ln.h" -#include "ln_utils.cuh" -#include "ln_kernel_traits.h" -#include "static_switch.h" - -namespace layer_norm { - -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_fwd_kernel(FwdParams params) { - - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using input_t = typename Ktraits::input_t; - using residual_t = typename Ktraits::residual_t; - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = typename Ktraits::compute_t; - using mask_t = typename Ktraits::mask_t; - using Ivec = typename Ktraits::Ivec; - using Rvec = typename Ktraits::Rvec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Mvec = typename Ktraits::Mvec; - - using Stats = typename Ktraits::Stats; - using stats_t = typename Stats::stats_t; - - const bool has_residual = params.residual != nullptr; - const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same::value); - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t r = bidm * ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); - - compute_t *mu_ptr = static_cast(params.mu); - compute_t *rs_ptr = static_cast(params.rs); - - const input_t *rowscale = static_cast(params.rowscale); - const index_t *x0_subset = static_cast(params.x0_subset); - const index_t *z_subset = static_cast(params.z_subset); - - const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG; - - Wvec gamma[LDGS]; - Wvec beta[LDGS]; - Wvec colscale[LDGS]; - index_t idx = c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - gamma[it].load_from(params.gamma, idx); - if (params.beta != nullptr) { - beta[it].load_from(params.beta, idx); - } else { - beta[it].zero_(); - } - if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } - idx += VEC_COLS_PER_LDG; - } - } - - for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; - const int row_x0 = !Has_subset ? row + 1 : x0_subset[row]; - const int row_z = !Has_subset ? row + 1 : z_subset[row]; - const bool load_x0 = !Has_subset || row_x0 > 0; - index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; - index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); - compute_t xf[LDGS * NUM_ELTS]; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Ivec x0; - Rvec residual; - Rvec x; - Mvec dmask; - if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } - if (has_residual) { residual.load_from(params.residual, idx_x); } - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use - // the more efficient curand_uniform4. - compute_t x_ij; - if (load_x0) { - mask_t keep = true; - if (Is_dropout) { dmask.data.elt[jt] = keep; } - compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val; - x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; - if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); } - x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij; - } else { - x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f; - } - if (save_x) { x.data.elt[jt] = x_ij; } - xf[it * NUM_ELTS + jt] = x_ij; - } - if (save_x) { x.store_to(params.x, idx_x); } - if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); } - idx_x += VEC_COLS_PER_LDG; - idx_x0 += VEC_COLS_PER_LDG; - } - } - - static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now"); - const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG; - const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG; - const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG; - auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int { - // Need to convert to int, otherwise the subtraction will wrap around. - const index_t valid_partial_vecs_in_warp = - std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)), - int(THREADS_PER_WARP)); - return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS; - }; - stats_t s = stats.template compute( - xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS - ); - - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - mu_ptr[row] = mu; - } - - compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu)); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - rs_ptr[row] = rs; - } - - const bool save_z = !Has_subset || row_z > 0; - if (save_z) { - index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Ovec z; - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f))); - compute_t g_ij = gamma[it].data.elt[jt]; - compute_t b_ij = beta[it].data.elt[jt]; - z.data.elt[jt] = output_t(g_ij * y_ij + b_ij); - } - z.store_to(params.z, idx_z); - idx_z += VEC_COLS_PER_LDG; - } - } - } - - } -} - -} // namespace layer_norm - -using namespace layer_norm; - -template< - typename weight_t, - typename input_t, - typename residual_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG -> -void launch_(LaunchParams &launch_params, const bool configure_params){ - - using Kernel_traits = Kernel_traits; - bool has_colscale = launch_params.params.colscale != nullptr; - bool has_subset = launch_params.params.x0_subset != nullptr; - bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; - BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { - BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { - BOOL_SWITCH(has_subset, HasSubsetConst, [&] { - BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { - auto kernel = &ln_fwd_kernel; - if( configure_params ) { - int ctas_per_sm; - CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; - launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::Stats::stats_t) - * 2; - } - return; - } - - if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); - } - }); - }); - }); - }); -} diff --git a/candle-kernels/src/ln_kernel_traits.h b/candle-kernels/src/ln_kernel_traits.h deleted file mode 100644 index 77de6bf9..00000000 --- a/candle-kernels/src/ln_kernel_traits.h +++ /dev/null @@ -1,172 +0,0 @@ -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace layer_norm { -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename residual_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t THREADS_PER_CTA_ -> -struct Kernel_traits_base { - - using weight_t = weight_t_; - using input_t = input_t_; - using residual_t = residual_t_; - using output_t = output_t_; - using compute_t = compute_t_; - using index_t = index_t_; - - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; - enum { THREADS_PER_WARP = 32 }; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename residual_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - bool Has_colscale, - uint32_t THREADS_PER_CTA_, - uint32_t BYTES_PER_LDG_, - typename Base = Kernel_traits_base -> -struct Kernel_traits_finalize : public Base { - enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; - static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); - // Bytes per global load from the input. - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - // Number of elements fetched by a global load. - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; - // Bytes per global store of the weights. - enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; - static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); - static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); - // The total number of BYTES_PER_LDG-wide words in a hidden vector. - enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; - static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); - - // Shared memory size to transpose the CTA result. - enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; - // Shared memory size to coalsece the CTA result. - enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; - // Shared memory requirement per CTA. - static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2; - enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT }; - - // The type of the reducer. - using Reducer = layer_norm::Reducer; - - // Condition for the whole CTA to participate in syncthreads. - static_assert(COLS % Base::THREADS_PER_WARP == 0); - enum { CTAS = COLS / Base::THREADS_PER_WARP }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template< - typename weight_t_, - typename input_t_, - typename residual_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t HIDDEN_SIZE_, - uint32_t CTAS_PER_ROW_, - uint32_t WARPS_M_, - uint32_t WARPS_N_, - uint32_t BYTES_PER_LDG_ = 16, - typename Base = Kernel_traits_base< - HIDDEN_SIZE_, - weight_t_, - input_t_, - residual_t_, - output_t_, - compute_t_, - index_t_, - WARPS_M_*WARPS_N_*THREADS_PER_WARP - > -> -struct Kernel_traits : public Base { - - using input_t = typename Base::input_t; - using residual_t = typename Base::residual_t; - using weight_t = typename Base::weight_t; - using compute_t = typename Base::compute_t; - using output_t = typename Base::output_t; - using index_t = typename Base::index_t; - // using mask_t = unsigned char; - using mask_t = bool; - - enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; - enum { WARPS_M = WARPS_M_ }; - enum { WARPS_N = WARPS_N_ }; - enum { COLS = HIDDEN_SIZE_ }; - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; - - enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; - enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; - enum { ROWS_PER_CTA = WARPS_M }; - - enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; - enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; - // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed - enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; - static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); - - using reduce_t = typename layer_norm::TypeToVec2::Type; - using Reducer = layer_norm::Reducer; - - enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; - enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; - - using Ivec = layer_norm::Vec; - using Rvec = layer_norm::Vec; - using Ovec = layer_norm::Vec; - using Wvec = layer_norm::Vec; - using Cvec = layer_norm::Vec; - using Mvec = layer_norm::Vec; - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; - - // Assume that each thread can handle the same number of elements in the output and weights as in the input. - static_assert(sizeof(input_t) == sizeof(output_t)); - static_assert(sizeof(input_t) <= sizeof(residual_t)); - // The number of columns fetched per load from input: one per thread. - enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; - // The total number of vectorized loads/stores per hidden vector. - enum { VEC_COLS = COLS / ELTS_PER_LDG }; - // The number of loads per thread for the input. - enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; - static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); - //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); - - using Stats = layer_norm::Stats; - enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/candle-kernels/src/ln_utils.cuh b/candle-kernels/src/ln_utils.cuh deleted file mode 100644 index 178d6fda..00000000 --- a/candle-kernels/src/ln_utils.cuh +++ /dev/null @@ -1,783 +0,0 @@ -#pragma once - -#include - -#include -#include - -#include "ln.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -constexpr uint32_t THREADS_PER_WARP = 32; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline void check_cuda_(cudaError_t status, const char *file, int line) { - if( status != cudaSuccess ) { - fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); - exit(status); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(ans) \ - { check_cuda_((ans), __FILE__, __LINE__); } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define DIVUP(x, y) (((x) + ((y)-1)) / (y)) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_( \ - launch_params, configure_params); \ - } \ - static FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_(launch_params, configure_params); \ - } \ - static BwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_parallel_residual_( \ - launch_params, configure_params); \ - } \ - static FwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ - ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_PARALLEL_BWD_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_parallel_residual_(launch_params, configure_params); \ - } \ - static BwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ - ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 operator+(const float2 & a, const float2 & b){ - return {a.x + b.x, a.y + b.y}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void operator+=(float2 & a, const float2 & b){ - a.x += b.x; - a.y += b.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Sum { - inline __device__ Sum(){} - inline __device__ T operator()(const T &a, const T &b){ - return a + b; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ - return __shfl_xor_sync(uint32_t(-1), x, idx); -} - -template<> -inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx){ - return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; -} - -template -inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ - return __shfl_down_sync(uint32_t(-1), x, idx); -} - -template<> -inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx){ - return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct uint16 { - uint4 u; - uint4 v; - uint4 s; - uint4 t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct uint8 { - uint4 u; - uint4 v; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BytesToType {}; - -template<> -struct BytesToType<64> { - using Type = uint16; - static_assert(sizeof(Type) == 64); -}; - -template<> -struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); -}; - -template<> -struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> -struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> -struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> -struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> -struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeToVec2 {}; - -template<> -struct TypeToVec2 { - using Type = float2; -}; - -template<> -struct TypeToVec2 { - using Type = half2; -}; - -template<> -struct TypeToVec2 { - using Type = nv_bfloat162; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Get { - template - static inline __device__ R of(const T &vec); -}; - -template<> -template -inline __device__ R Get<0>::of(const T &vec) { - return vec.x; -} - -template<> -template -inline __device__ R Get<1>::of(const T &vec) { - return vec.y; -} - -template<> -template -inline __device__ R Get<2>::of(const T &vec) { - return vec.z; -} - -template<> -template -inline __device__ R Get<3>::of(const T &vec) { - return vec.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Converter{ - static inline __device__ Dst convert(const Src &from) { - return Dst(from); - } -}; - -template<> -struct Converter{ - static inline __device__ half2 convert(const float2 &x) { - return __float22half2_rn(x); - } -}; - -template<> -struct Converter{ - static inline __device__ nv_bfloat162 convert(const float2 &x) { -#if __CUDA_ARCH__ >= 800 - return __float22bfloat162_rn(x); -#else - union { - nv_bfloat162 raw; - nv_bfloat16 x; - nv_bfloat16 y; - } tmp; - tmp.x = __float2bfloat16_rn(x.x); - tmp.y = __float2bfloat16_rn(x.y); - return tmp.raw; -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Zeros{ - static inline __device__ T get() { - return T(0.f); - } -}; - -template<> -struct Zeros{ - static inline __device__ float2 get() { - return make_float2(0.f, 0.f); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Vec { - - enum { BYTES = NUM_ELT * sizeof(Elt_type) }; - - using Vec_type = typename BytesToType::Type; - - using Alias_type = union { - Vec_type vec; - Elt_type elt[NUM_ELT]; - }; - - Alias_type data; - - template - inline __device__ void to(Vec &other) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - other.data.elt[it] = S(this->data.elt[it]); - } - } - - template - inline __device__ void assign(const Op &op) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = op(it); - } - } - - inline __device__ void zero_() { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = Elt_type(0.f); - } - } - - inline __device__ void load_from(const void *base_ptr, const size_t idx) { - this->data.vec = static_cast(base_ptr)[idx]; - } - - inline __device__ void store_to(void *base_ptr, const size_t idx) { - static_cast(base_ptr)[idx] = this->data.vec; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct InterCTASync { - - template - inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn) - : phase_counter_(0) - , b0_(params.barrier + bidm) // The barrier for this group of CTAs. - , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. - { - // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! - } - - inline __device__ void spin_wait_(int *barrier, int step, int expected) { - asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); - for( int found = -1; found != expected; ) { - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); - } - } - - inline __device__ void sync(){ - // ALL THREADS MUST ENTER! - - // We switch barrier every iteration. - int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; - // We decrement every other iteration. - bool dec = phase_counter_ & 0x2; - int step = dec ? -1 : 1; - int expected = dec ? 0 : CTAS_PER_ROW; - // There are only 4 phases: up/down for b0/b1. - phase_counter_ = (phase_counter_ + 1) & 0x3; - - if( threadIdx.x == 0 ) { - spin_wait_(barrier, step, expected); - } - // CTA waits for thread 0 - __syncthreads(); - } - - int phase_counter_; - int * b0_; - int * b1_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer : public Reducer { - - using InterCTASync = InterCTASync; - using Base = Reducer; - using Type = typename Base::Type; - - enum { SMEM_BYTES = Base::SMEM_BYTES }; - - enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; - enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; - - // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) - enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , inter_cta_(params, bidm, bidn) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) - { - } - - template - inline __device__ T allreduce(T data, Op &op) { - data = Base::reduce(data, op); - // We switch workspace every iteration. - T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - // Warp leaders 0 hold the CTA-local results. - if( this->warp_n_ == 0 && this->lane_ == 0 ) { - workspace[bidn_] = data; - } - inter_cta_.sync(); - static_assert(CTAS_PER_ROW <= 32); - T total = Zeros::get(); - if(this->lane_ < CTAS_PER_ROW){ - total = workspace[this->lane_]; - } - total = Reducer::allreduce_(total, op); - - return total; - } - - InterCTASync inter_cta_; - - T *w0_; - T *w1_; - int bidn_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer { - - using Type = T; - enum { SMEM_BYTES = 0 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : warp_n_(warp_n) - , lane_(lane) - { - } - - template - static inline __device__ T allreduce_(T data, Op &op) { - #pragma unroll - for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) { - data = op(data, warp_shuffle_xor(data, it)); - } - return data; - } - - template - inline __device__ T allreduce(T data, Op &op) { - return allreduce_(data, op); - } - - template - inline __device__ T reduce(T data, Op &op){ - // only lane 0 holds the result! - #pragma unroll - for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { - data = op(data, warp_shuffle_down(data, it)); - } - return data; - } - int warp_n_; - int lane_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer : public Reducer { - - using Base = Reducer; - - using Type = T; - - enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true) - { - smem0_ = &static_cast(smem)[warp_m * WARPS_N]; - smem1_ = smem0_ + WARPS_M * WARPS_N; - } - - template - inline __device__ T allreduce(T data, Op & op) { - T * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - data = Base::reduce(data, op); - if( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - #pragma unroll - for( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - return out; - } - - template - inline __device__ T reduce(T data, Op &op) { - T * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // only intra-CTA group leader holds the result! - data = Base::reduce(data, op); - if( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - if( this->warp_n_ == 0 && this->lane_ == 0 ) { - #pragma unroll - for( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - } - return out; - } - - T * smem0_; - T * smem1_; - bool use0_; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){ - //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) - const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); - - #pragma unroll - for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { - // Exchange - int_t n_b = warp_shuffle_down(n_a, step); - T m_b = warp_shuffle_down(m_a, step); - T m2_b = warp_shuffle_down(m2_a, step); - - // Update - const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both. - const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( - const T delta = m_a - m_b; - const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; - const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; - - n_a = n_ab; - m_a = m_ab; - m2_a = m2_ab; - } - // Intra-warp broadcast (only lane 0 has valid stats). - m_a = __shfl_sync(uint32_t(-1), m_a, 0); - m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. - - using InterCTASync = InterCTASync; - using BlockStats = Stats; - using stats_t = typename BlockStats::stats_t; - - enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : inter_cta_(params, bidm, bidn) - , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) - , warp_n_(warp_n) - , lane_(lane) - { - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; - // TODO rn is not really needed here.. - constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); - stats_t block_stats = block_stats_.compute(elts, block_rn); - - stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - if( warp_n_ == 0 && lane_ == 0 ) { - workspace[bidn_] = block_stats; - } - - // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. - inter_cta_.sync(); - - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume CTA group size in N less than 32, such that we can finalize with a single warp. - static_assert(CTAS_PER_ROW <= 32); - - // Every warp does the final reduction locally. - if( lane_ < CTAS_PER_ROW ) { - stats_t result = workspace[lane_]; - n = ELTS_PER_ROW_PER_CTA; - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); - - return { m, m2 }; - } - - InterCTASync inter_cta_; - BlockStats block_stats_; - - stats_t *w0_; - stats_t *w1_; - int bidn_; - int warp_n_; - int lane_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - - using WarpStats = Stats; - using stats_t = typename WarpStats::stats_t; - - enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true) - { - smem0_ = static_cast(smem) + warp_m * WARPS_N; - smem1_ = smem0_ + WARPS_M * WARPS_N; - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, - function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { - stats_t * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // Compute warp local for all WARPS_N - const auto warp_n = warp_stats_.reducer_.warp_n_; - const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n)); - stats_t warp_stats = warp_stats_.template compute( - elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts - ); - - //Each warp warp leader stores its stats - const auto lane = warp_stats_.reducer_.lane_; - if( lane == 0 ) { - smem[warp_n] = warp_stats; - } - __syncthreads(); - - int n = 0;; - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume that there are less than 32 warps, such that we can finalize with a single warp - static_assert(WARPS_N <= 32); - if(lane < WARPS_N){ - stats_t result = smem[lane]; - n = Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane); - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, WARPS_N); - - return { m, m2 }; - } - WarpStats warp_stats_; - stats_t * smem0_; - stats_t * smem1_; - bool use0_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - - using stats_t = typename TypeToVec2::Type; - // The simple Warp reducer. - using Reducer = Reducer; - - enum { SMEM_BYTES = 0 }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) - { - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, - // const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) { - function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { - - auto sum = Sum(); - - T m = Zeros::get(); - #pragma unroll - for( int it = 0; it < N; it++ ) { - if (Is_even_cols || (it < num_valid_elts)) { - m += elts[it]; - } - } - m = reducer_.allreduce(m, sum) * row_norm_factor; - - T m2 = Zeros::get(); - #pragma unroll - for( int it = 0; it < N; it++ ) { - if (Is_even_cols || (it < num_valid_elts)) { - T diff = (elts[it] - m); - m2 += diff * diff; - } - } - m2 = reducer_.allreduce(m2, sum); - - return {m, m2}; - } - - Reducer reducer_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/candle-kernels/src/static_switch.h b/candle-kernels/src/static_switch.h deleted file mode 100644 index 7920ac04..00000000 --- a/candle-kernels/src/static_switch.h +++ /dev/null @@ -1,25 +0,0 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }()