From 3a7304cb0dbdf8ceeab8a4f5cf9b8e7ced822e20 Mon Sep 17 00:00:00 2001 From: Jeroen Vlek Date: Fri, 5 Jan 2024 11:59:46 +0100 Subject: [PATCH 1/8] add link to gpt-from-scratch-rs (#1525) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a03367a5..93cbccc4 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,7 @@ And then head over to - [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle. - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more. - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle. +- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem. If you have an addition to this list, please submit a pull request. From 8d1a57c9a0465b201e4e9e410e2b8fcde37b35f7 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 5 Jan 2024 18:28:55 +0100 Subject: [PATCH 2/8] chore: update flash attention kernels (#1518) * chore: update flash attention kernels * fmt * remove unused kernels * force f32 * correct stride --- candle-flash-attn/kernels/alibi.h | 62 +++ candle-flash-attn/kernels/block_info.h | 13 +- candle-flash-attn/kernels/flash.h | 54 ++- candle-flash-attn/kernels/flash_api.cu | 40 +- .../kernels/flash_fwd_hdim128_bf16_sm80.cu | 13 +- .../kernels/flash_fwd_hdim128_fp16_sm80.cu | 26 +- .../kernels/flash_fwd_hdim160_bf16_sm80.cu | 11 +- .../kernels/flash_fwd_hdim160_fp16_sm80.cu | 21 +- .../kernels/flash_fwd_hdim192_bf16_sm80.cu | 12 +- .../kernels/flash_fwd_hdim192_fp16_sm80.cu | 21 +- .../kernels/flash_fwd_hdim224_bf16_sm80.cu | 5 +- .../kernels/flash_fwd_hdim224_fp16_sm80.cu | 5 +- .../kernels/flash_fwd_hdim256_bf16_sm80.cu | 5 +- .../kernels/flash_fwd_hdim256_fp16_sm80.cu | 5 +- .../kernels/flash_fwd_hdim32_bf16_sm80.cu | 4 +- .../kernels/flash_fwd_hdim32_fp16_sm80.cu | 17 +- .../kernels/flash_fwd_hdim64_bf16_sm80.cu | 13 +- .../kernels/flash_fwd_hdim64_fp16_sm80.cu | 20 +- .../kernels/flash_fwd_hdim96_bf16_sm80.cu | 11 +- .../kernels/flash_fwd_hdim96_fp16_sm80.cu | 21 +- candle-flash-attn/kernels/flash_fwd_kernel.h | 282 ++++++----- .../kernels/flash_fwd_launch_template.h | 63 +-- candle-flash-attn/kernels/kernel_traits.h | 77 +++- .../kernels/kernel_traits_sm90.h | 159 +++++++ candle-flash-attn/kernels/softmax.h | 57 ++- candle-flash-attn/kernels/utils.h | 92 ++-- candle-flash-attn/src/ffi.rs | 8 +- candle-flash-attn/src/lib.rs | 436 +++++++++++++++++- 28 files changed, 1087 insertions(+), 466 deletions(-) create mode 100644 candle-flash-attn/kernels/alibi.h create mode 100644 candle-flash-attn/kernels/kernel_traits_sm90.h diff --git a/candle-flash-attn/kernels/alibi.h b/candle-flash-attn/kernels/alibi.h new file mode 100644 index 00000000..1afb3687 --- /dev/null +++ b/candle-flash-attn/kernels/alibi.h @@ -0,0 +1,62 @@ +#include + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void apply_alibi(Tensor &tensor, + const int col_idx_offset_, + const int 
max_seqlen_k, + const int row_idx_offset, + const int max_seqlen_q, + const int warp_row_stride, + const float alibi_slope) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + } + } + } else { // Bias depends on both row_idx and col_idx + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + } + } + } +} + +} // namespace flash diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h index 94251a41..65435e51 100644 --- a/candle-flash-attn/kernels/block_info.h +++ b/candle-flash-attn/kernels/block_info.h @@ -14,9 +14,12 @@ struct BlockInfo { template __device__ BlockInfo(const Params ¶ms, const int bidb) : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) - , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) - , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { } @@ -32,8 +35,10 @@ struct BlockInfo { const int sum_s_q; const int sum_s_k; - const uint32_t actual_seqlen_q; - const uint32_t actual_seqlen_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. 
+ const int seqlen_k_cache; + const int actual_seqlen_k; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h index be4ae0ca..80b517e9 100644 --- a/candle-flash-attn/kernels/flash.h +++ b/candle-flash-attn/kernels/flash.h @@ -7,15 +7,6 @@ #include #include -// #ifdef OLD_GENERATOR_PATH -// #include -// #else -// #include -// #endif -// -// #include - - constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; @@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params { // The O matrix (output). void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; // The stride between rows of O. index_t o_batch_stride; @@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params { // The pointer to the softmax sum. void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; // The scaling factors for the kernel. float scale_softmax; @@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params { int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + int *__restrict__ blockmask; + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int *__restrict__ cache_batch_idx; + // The dropout probability (probability of keeping an activation). float p_dropout; // uint32_t p_dropout_in_uint; @@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params { float rp_dropout; float scale_softmax_rp_dropout; - // Random state. - // at::PhiloxCudaState philox_args; + // Local window size + int window_size_left, window_size_right; bool is_bf16; bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params { // The pointer to the softmax d sum. 
void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index 72991257..8113dbc7 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -1,17 +1,15 @@ #include "flash_fwd_launch_template.h" -// void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { -// FWD_HEADDIM_SWITCH(params.d, [&] { -// run_mha_fwd_(params, stream); -// }); -// } - -void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - FP16_SWITCH(!params.is_bf16, [&] { - FWD_HEADDIM_SWITCH(params.d, [&] { - run_mha_fwd_(params, stream); - }); - }); +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { +// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); +// } else { +// run_mha_fwd_splitkv_dispatch(params, stream); +// } + }); + }); } extern "C" void run_mha( @@ -20,6 +18,7 @@ extern "C" void run_mha( void *v_ptr, void *o_ptr, void *softmax_lse_ptr, + void *alibi_slopes_ptr, int32_t *cu_seqlens_q_ptr, int32_t *cu_seqlens_k_ptr, @@ -28,6 +27,7 @@ extern "C" void run_mha( uint32_t k_batch_stride, uint32_t v_batch_stride, uint32_t o_batch_stride, + uint32_t alibi_slopes_batch_stride, uint32_t q_row_stride, uint32_t k_row_stride, @@ -51,8 +51,11 @@ extern "C" void run_mha( uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded, + int is_bf16, int is_causal, - int is_bf16 + + int window_size_left, + int window_size_right ) { Flash_fwd_params params; // Reset the parameters @@ -65,12 +68,14 @@ extern "C" void run_mha( params.o_ptr = o_ptr; params.softmax_lse_ptr = softmax_lse_ptr; + params.alibi_slopes_ptr = alibi_slopes_ptr; // All stride are in elements, not bytes. params.q_batch_stride = q_batch_stride; params.k_batch_stride = k_batch_stride; params.v_batch_stride = v_batch_stride; params.o_batch_stride = o_batch_stride; + params.alibi_slopes_batch_stride = alibi_slopes_batch_stride; params.q_row_stride = q_row_stride; params.k_row_stride = k_row_stride; @@ -92,7 +97,6 @@ extern "C" void run_mha( params.seqlen_k_rounded = seqlen_k_rounded; params.d = d; params.d_rounded = d_rounded; - params.is_causal = is_causal; // Set the different scale values. params.scale_softmax = softmax_scale; @@ -106,6 +110,14 @@ extern "C" void run_mha( params.cu_seqlens_q = cu_seqlens_q_ptr; params.cu_seqlens_k = cu_seqlens_k_ptr; params.p_ptr = nullptr; // used for `return_softmax`. + params.seqused_k = nullptr; + + params.is_causal = is_causal; + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; + params.num_splits = 1; cudaStream_t stream = 0; // Use the default stream. 
run_mha_fwd(params, stream); diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu index 654400a7..6ffa4126 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu @@ -1,19 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// if (params.p_dropout == 1.f) { -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu index 5b7254a9..19b005ad 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu @@ -1,32 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// if (params.p_dropout == 1.f) { -// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k -// run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// // 1st ones are good for H100, A100 -// // 2nd one is good for A6000 bc we get slightly better occupancy -// } else { -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// // 1st one is good for H100, A100, A6000 -// } -// } - template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu index 6a9d60c3..f674f481 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu @@ -1,17 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu index 6c40a164..afd0a8a3 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu @@ -1,27 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest. -// // For A100, H100, 1st is fastest. -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu index d2f4cba7..aa91bdd6 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu @@ -1,16 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu index 2875c926..37a96526 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu @@ -1,27 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // This one is slightly faster for causal? -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// }); -// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout -// // For A6000, 1st is faster when causal, 3rd is faster when not causal -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu index 982fe7ea..167a0df2 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim224(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu index 4c083f7b..58ffe75c 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim224(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu index cb074a95..1b370141 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu index ddf5e132..9f35129c 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu index 81e359e1..770de6fc 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu @@ -1,10 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu index 91e6331e..8dbf8b94 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu @@ -1,23 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// // For dropout there might be a lot of register spilling? -// // These two are very slow due to register spilling -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // This one is slightly slower -// // run_flash_fwd>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu index fffcbebb..22eac878 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu @@ -1,19 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// if (params.p_dropout == 1.f) { -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu index 01bd1716..e6da5dd2 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu @@ -1,26 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// if (params.p_dropout == 1.f) { -// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower -// // Using block size (64 x 256) is 27% slower for seqlen=2k -// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu index b0b27db5..9c003540 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu @@ -1,17 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu index 820b63cb..8108696a 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu @@ -1,23 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // This 3rd one is good for H100, and A100, A6000 -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // These two are always slower -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// }); -// } -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h index 232dea0d..05f5f701 100644 --- a/candle-flash-attn/kernels/flash_fwd_kernel.h +++ b/candle-flash-attn/kernels/flash_fwd_kernel.h @@ -4,20 +4,18 @@ #pragma once -#include #include -#include #include #include #include -#include #include "block_info.h" #include "kernel_traits.h" #include "utils.h" #include "softmax.h" -#include "philox.cuh" + +#include "alibi.h" namespace flash { @@ -25,49 +23,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - // TODO: Shouldn't this be size<1>? 
- make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, Tensor2 &acc_o, float softmax_scale_log2) { @@ -77,7 +32,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T flash::reduce_sum(scores, scores_sum); } else { Tensor scores_max_prev = make_fragment_like(scores_max); - copy(scores_max, scores_max_prev); + cute::copy(scores_max, scores_max_prev); flash::template reduce_max(scores, scores_max); // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); @@ -103,23 +58,22 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T template inline __device__ void write_softmax_to_gmem( - Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_thr_copy_P + Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_tiled_copy_P ) { // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) Layout l = tOrP.layout(); Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); - // TODO(laurent): reactivate the following - // CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); + CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); #pragma unroll for (int mi = 0; mi < size<1>(tPrP); ++mi) { - copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -138,16 +92,65 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kNWarps = Kernel_traits::kNWarps; constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // } } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. 
+ if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + // Save seed and offset for backward. If we don't have this here, the 0-th thread block might + // exit early and no one saves the rng state. +// if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { +// auto seeds = at::cuda::philox::unpack(params.philox_args); +// params.rng_state[0] = std::get<0>(seeds); +// params.rng_state[1] = std::get<1>(seeds); +// params.rng_state[0] = 0; +// params.rng_state[1] = 0; +// } + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } // We iterate over the blocks in reverse order. This is because the last block is the only one // that needs masking when we read K and V from global memory. 
Moreover, iterating in reverse @@ -185,8 +188,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); - auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -208,16 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Copy Atom retiling // - auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); - // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); // if (cute::thread0()) {smem_thr_copy_Q.print_all();} Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} - auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); - auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); // TODO: this might need to change if we change the mma instruction in SM70 @@ -268,8 +275,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tQrQ = make_fragment_like(tQgQ); // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } // // Copy rmem to smem @@ -285,14 +292,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); } int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
- flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } // __syncthreads(); @@ -302,7 +309,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); } // auto seeds = at::cuda::philox::unpack(params.philox_args); @@ -313,13 +320,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); + float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. // We need masking on S for the very last block when K and V has length not multiple of kBlockN. // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. // We will have at least 1 "masking" iteration. - constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1; + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -330,28 +343,42 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Advance gV if (masking_step > 0) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( - gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K ); // if (cute::thread0()) { print(acc_s); } // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print(scores); } + // if (cute::thread0()) { print_tensor(scores); } // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. 
- if (!Is_causal) { - if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + + if (Has_alibi) { + flash::apply_alibi( + scores, + n_block * kBlockN, + binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16, + alibi_slope + ); + } + + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) @@ -364,20 +391,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Idk why it's get<1> and not get<0> of the stride. // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16 + // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16 + ); + // if (cute::thread0()) { print_tensor(scores); } } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -385,24 +416,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor tOrP_copy = make_fragment_like(tOrP); - copy(tOrP, tOrP_copy); + cute::copy(tOrP, tOrP_copy); flash::apply_dropout( tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps ); - flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); tPgP.data() = tPgP.data() + (-kBlockN); } if (Is_dropout) { @@ -411,37 +442,38 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // if (cute::thread0()) { print(tOrP); } - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); // Advance gV tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K ); flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); @@ -449,22 +481,44 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + if (Has_alibi) { + flash::apply_alibi( + scores, + n_block * kBlockN, + binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16, + alibi_slope + ); + } + + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); + } + + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor tOrP_copy = make_fragment_like(tOrP); - copy(tOrP, tOrP_copy); + cute::copy(tOrP, tOrP_copy); flash::apply_dropout( tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps ); - flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); tPgP.data() = tPgP.data() + (-kBlockN); } if (Is_dropout) { @@ -472,7 +526,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi block_row_idx, block_col_idx, kNWarps); } - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue @@ -496,15 +550,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor rO = flash::convert_type(acc_o); Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning - auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); - // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) // sO has the same size as sQ, so we don't need to sync here. 
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } - copy(smem_thr_copy_O, taccOrO, taccOsO); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; @@ -515,14 +569,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); - auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); __syncthreads(); Tensor tOrO = make_tensor(shape(tOgO)); - copy(gmem_thr_copy_O, tOsO, tOrO); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) @@ -548,14 +603,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); } + //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -571,7 +627,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index 398ce077..66ab6206 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -4,15 +4,14 @@ #pragma once -// #include - #include "static_switch.h" #include "flash.h" #include "flash_fwd_kernel.h" -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { - flash::compute_attn(params); + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + flash::compute_attn(params); } template @@ -26,35 +25,39 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid(num_m_block, params.b, params.h); - // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check - // for cu_seqlens_q as well. 
- const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool return_softmax = params.p_ptr != nullptr; - BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // if (smem_size >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // } - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - // C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); }); }); }); } + template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 32; + constexpr static int Headdim = 32; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_dropout, Is_causal>(params, stream); @@ -64,7 +67,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 64; + constexpr static int Headdim = 64; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { @@ -86,7 +89,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 96; + constexpr static int Headdim = 96; // auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { @@ -112,7 +115,7 @@ void 
run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 128; + constexpr static int Headdim = 128; // auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { @@ -149,7 +152,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 160; + constexpr static int Headdim = 160; // auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { @@ -179,7 +182,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 192; + constexpr static int Headdim = 192; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { @@ -198,7 +201,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 224; + constexpr static int Headdim = 224; int device; cudaGetDevice(&device); int max_smem_per_block; @@ -224,7 +227,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 256; + constexpr static int Headdim = 256; int device; cudaGetDevice(&device); int max_smem_per_sm, max_smem_per_block; diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h index 3468e4bf..f000ff24 100644 --- a/candle-flash-attn/kernels/kernel_traits.h +++ b/candle-flash-attn/kernels/kernel_traits.h @@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape, Int>{})); + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle{}, - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); using SmemLayoutVtransposed = decltype(tile_to_shape( SmemLayoutAtomVtransposed{}, Shape, Int>{})); // Maybe the VtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? 
- using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); using SmemLayoutAtomO = decltype( composition(Swizzle{}, @@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; static constexpr int kSmemQCount = size(SmemLayoutQ{}); static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; @@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base { DefaultCopy >; using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; @@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>; using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtomP{}, Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. @@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomKtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); using SmemLayoutKtransposed = decltype(tile_to_shape( SmemLayoutAtomKtransposed{}, make_shape(Int{}, Int{}))); // Maybe the KtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? - using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 
3 is faster than 1, and 1 is faster than 2 @@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomPdStransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); using SmemLayoutPdStransposed = decltype(tile_to_shape( SmemLayoutAtomPdStransposed{}, make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); using SmemCopyAtomPdS = Copy_Atom; + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomQdOtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); using SmemLayoutQdOtransposed = decltype(tile_to_shape( SmemLayoutAtomQdOtransposed{}, make_shape(Int{}, Int{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); using SmemLayoutAtomdKV = decltype( composition(Swizzle{}, @@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); static constexpr int kSmemPCount = size(SmemLayoutPdS{}); static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); - static constexpr int kSmemdPsumCount = kBlockM; static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); - static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) diff --git a/candle-flash-attn/kernels/kernel_traits_sm90.h b/candle-flash-attn/kernels/kernel_traits_sm90.h new file mode 100644 index 00000000..e07f3839 --- /dev/null +++ b/candle-flash-attn/kernels/kernel_traits_sm90.h @@ -0,0 +1,159 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits_sm90 { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? 
+ using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/softmax.h b/candle-flash-attn/kernels/softmax.h index 3e9a7b45..09a93f14 100644 --- a/candle-flash-attn/kernels/softmax.h +++ b/candle-flash-attn/kernels/softmax.h @@ -8,8 +8,7 @@ #include -#include -#include +#include #include "philox.cuh" #include "utils.h" @@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tens } template -inline __device__ void apply_mask(Tensor &tensor, const uint32_t max_seqlen_k) { +inline __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < 
size<1, 0>(tensor); ++j) { - const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2; + const int col_idx = col_idx_base + j; if (col_idx >= max_seqlen_k) { // Without the "make_coord" we get wrong results #pragma unroll @@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor &tensor, const uint32_t } } -template -inline __device__ void apply_mask_causal(Tensor &tensor, const uint32_t col_idx_offset_, - const uint32_t max_seqlen_k, const uint32_t row_idx_offset_, - const uint32_t warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; - // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4; - const uint32_t row_idx_offset = row_idx_offset_; - const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride; + const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { - const uint32_t row_idx = row_idx_base + i * 8; - const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1); + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const uint32_t col_idx_base = col_idx_offset + nj * 8; + const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { - const uint32_t col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor &tensor, const u } } +template +inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor &tensor, Tensor const &idx_rowcol, - const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_) + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout0::rank == 2, "Only support 2D Tensor"); @@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx( CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { - const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + 
row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); #pragma unroll for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { @@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx( template inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, unsigned long long seed, unsigned long long offset, - uint32_t block_row_start, uint32_t block_col_start, - uint32_t block_row_stride) { + int block_row_start, int block_col_start, + int block_row_stride) { // tensor has shape (8, MMA_M, MMA_N / 2) using T = typename Engine::value_type; auto encode_dropout = [](bool keep, T val) { diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h index 2221a2fa..6fb39dc4 100644 --- a/candle-flash-attn/kernels/utils.h +++ b/candle-flash-attn/kernels/utils.h @@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2(const float2 x) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ float2 half2_unpack(uint32_t a); - -template <> -inline __device__ float2 half2_unpack<__half>(uint32_t a) { - return __half22float2(reinterpret_cast<__half2 (&)>(a)); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template <> -inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { - return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert two half2's or bf162's into float, then take their dot product. -template -inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { - float2 af = flash::half2_unpack(a); - float2 bf = flash::half2_unpack(b); - return af.x * bf.x + af.y * bf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Converted two vectors of 8 half's or bf16's into float, then take their dot product. -template -inline __device__ float hmulsum8(const uint4 a, const uint4 b) { - float sum; - sum = flash::hfma2_to_float(a.x, b.x); - sum += flash::hfma2_to_float(a.y, b.y); - sum += flash::hfma2_to_float(a.z, b.z); - sum += flash::hfma2_to_float(a.w, b.w); - return sum; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template struct MaxOp { __device__ inline T operator()(T const & x, T const & y) { return x > y ? 
x : y; } @@ -173,10 +133,12 @@ static __device__ inline T run(T x, Operator &op) { template + typename TiledMma, typename TiledCopyA, typename TiledCopyB, + typename ThrCopyA, typename ThrCopyB> inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, TiledMma tiled_mma, - TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) { + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K @@ -184,13 +146,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } - if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } #pragma unroll for (int i = 0; i < size<2>(tCrA); ++i) { if (i < size<2>(tCrA) - 1) { - if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } - if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } } cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); } @@ -199,19 +161,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 //////////////////////////////////////////////////////////////////////////////////////////////////// template + typename TiledMma, typename TiledCopy, typename ThrCopy> inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, - TiledMma tiled_mma, TiledCopy smem_thr_copy_B) { + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); #pragma unroll for (int i = 0; i < size<2>(tCrA); ++i) { if (i < size<2>(tCrA) - 1) { - copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); } @@ -225,7 +188,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + 
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -241,9 +207,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { static_assert(mma_shape_K == 8 || mma_shape_K == 16); constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) - return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), - get<0, 1>(l), - get<1, 1, 1>(l)); + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -319,9 +289,9 @@ void cp_async_wait() { template -inline __device__ void copy(TiledCopy thr_copy, Tensor const &S, +inline __device__ void copy(TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, - Tensor const &predicate_K, int max_MN=0) { + Tensor const &predicate_K, const int max_MN=0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -335,13 +305,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const & #pragma unroll for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - copy(thr_copy, S(_, m, k), D(_, m, k)); + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { - clear(D(_, m, k)); + cute::clear(D(_, m, k)); } } } else if (Clear_OOB_MN) { - clear(D(_, m, _)); + cute::clear(D(_, m, _)); } } // TD [2023-04-13]: Strange that the code below can cause race condition. 
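
A note on the masking change above: the softmax.h hunks generalize the purely causal apply_mask_causal into apply_mask_local, and causal masking becomes the special case window_size_left = -1 (unbounded) with window_size_right = 0. The windowing arithmetic can be summarized by the following host-side sketch, written in Rust purely for readability; the helper name and the standalone framing are illustrative assumptions, not part of the patch:

fn window_mask_keeps(
    row_idx: i64,
    col_idx: i64,
    max_seqlen_q: i64,
    max_seqlen_k: i64,
    window_size_left: i64,  // < 0 means no limit on how far back a query may look
    window_size_right: i64, // causal attention is (unbounded left, right window of 0)
) -> bool {
    // Query rows are aligned to the end of the key sequence, hence this shift.
    let shift = row_idx + max_seqlen_k - max_seqlen_q;
    // Right-hand limit: how far past its own position a query may look.
    let limit_right = (shift + 1 + window_size_right).min(max_seqlen_k);
    if col_idx >= limit_right {
        return false;
    }
    // The kernel skips the left-hand check entirely when no left window is set.
    if window_size_left < 0 {
        return true;
    }
    // Left-hand limit: how far back a query may look when a sliding window is set.
    let limit_left = (shift - window_size_left).max(0);
    col_idx >= limit_left
}

This only describes the causal and local paths; when neither kind of window is requested, the Is_causal / Is_local switches in the launch template above select a kernel variant that does not apply this mask.
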
@@ -350,7 +320,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const & // #pragma unroll // for (int m = 0; m < size<1>(S); ++m) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - // copy(thr_copy, S(_, m, _), D(_, m, _)); + // copy(tiled_copy, S(_, m, _), D(_, m, _)); // } else if (Clear_OOB_MN) { // clear(D(_, m, _)); // } @@ -362,7 +332,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const & // #pragma unroll // for (int m = 0; m < size<1>(S); ++m) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - // copy(thr_copy, S(_, m, k), D(_, m, k)); + // copy(tiled_copy, S(_, m, k), D(_, m, k)); // } else if (Clear_OOB_MN) { // clear(D(_, m, k)); // } diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs index 90f34e43..ca65520b 100644 --- a/candle-flash-attn/src/ffi.rs +++ b/candle-flash-attn/src/ffi.rs @@ -7,6 +7,8 @@ extern "C" { v_ptr: *const c_void, o_ptr: *const c_void, softmax_lse_ptr: *const c_void, + alibi_slopes_ptr: *const c_void, + cu_seqlens_q_ptr: *const i32, cu_seqlens_k_ptr: *const i32, @@ -14,6 +16,7 @@ extern "C" { k_batch_stride: u32, v_batch_stride: u32, o_batch_stride: u32, + alibi_slopes_batch_stride: u32, q_row_stride: u32, k_row_stride: u32, @@ -37,8 +40,11 @@ extern "C" { seqlen_q_rounded: u32, seqlen_k_rounded: u32, - is_causal: c_int, is_bf16: c_int, + is_causal: c_int, + + window_size_left: c_int, + window_size_right: c_int, ); } diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 3395bd0d..21a06b5e 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -3,12 +3,14 @@ mod ffi; use candle::backend::BackendStorage; use candle::cuda_backend::cudarc::driver::DevicePtr; use candle::cuda_backend::WrapErr; -use candle::{CpuStorage, Layout, Result, Shape, Tensor}; +use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; use half::{bf16, f16}; pub struct FlashAttn { pub softmax_scale: f32, - pub causal: bool, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, } fn round_multiple(x: usize, m: usize) -> usize { @@ -85,6 +87,51 @@ impl FlashAttn { candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") } + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? 
{ + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + let head_size = round_multiple(head_size_og, 8); let head_size_rounded = round_multiple(head_size, 32); let seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -94,9 +141,22 @@ impl FlashAttn { let dst = unsafe { dev.alloc::(elem_count) }.w()?; let softmax_lse = dev.alloc_zeros::(b_sz * num_heads * seqlen_q).w()?; - let causal = if self.causal { 1 } else { 0 }; let is_bf16 = if is_bf16 { 1 } else { 0 }; + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = seqlen_k as i32; + } + unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; let k_ptr = *k.device_ptr() as *const core::ffi::c_void; @@ -109,12 +169,14 @@ impl FlashAttn { v_ptr, dst_ptr, softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ std::ptr::null(), /* cu_seqlens_k_ptr */ std::ptr::null(), /* q_batch_stride */ q_stride[0] as u32, /* k_batch_stride */ k_stride[0] as u32, /* v_batch_stride */ v_stride[0] as u32, /* o_batch_stride */ o_stride[0] as u32, + /* alibi_slopes_batch_stride */ 0, /* q_row_stride */ q_stride[q_rank - 3] as u32, /* k_row_stride */ k_stride[k_rank - 3] as u32, /* v_row_stride */ v_stride[v_rank - 3] as u32, @@ -133,8 +195,10 @@ impl FlashAttn { /* seqlen_k */ seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, - /* is_causal */ causal, /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, ) } @@ -197,20 +261,137 @@ pub fn flash_attn( softmax_scale: f32, causal: bool, ) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + let op = FlashAttn { softmax_scale, - causal, + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. 
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttn { + softmax_scale, + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. 
+pub fn flash_attn_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, }; q.apply_op3(k, v, op) } struct FlashAttnVarLen { - softmax_scale: f32, - causal: bool, - max_seqlen_q: usize, - max_seqlen_k: usize, - seqlens_q: Tensor, - seqlens_k: Tensor, + pub softmax_scale: f32, + pub max_seqlen_q: usize, + pub max_seqlen_k: usize, + pub seqlens_q: Tensor, + pub seqlens_k: Tensor, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, } impl FlashAttnVarLen { @@ -311,7 +492,54 @@ impl FlashAttnVarLen { if nseqlens_k != nseqlens_q { candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}") } + let batch_size = nseqlens_q - 1; + + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? { + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + let head_size = round_multiple(head_size_og, 8); let head_size_rounded = round_multiple(head_size, 32); let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128); @@ -323,9 +551,22 @@ impl FlashAttnVarLen { .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q) .w()?; - let causal = if self.causal { 1 } else { 0 }; let is_bf16 = if is_bf16 { 1 } else { 0 }; + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = self.max_seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = self.max_seqlen_k as i32; + } + unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; let k_ptr = *k.device_ptr() as *const core::ffi::c_void; @@ -340,12 +581,14 @@ impl FlashAttnVarLen { v_ptr, dst_ptr, softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ seqlens_q_ptr, /* cu_seqlens_k_ptr */ seqlens_k_ptr, /* q_batch_stride */ 0, /* k_batch_stride */ 0, /* v_batch_stride */ 0, /* o_batch_stride */ 0, + /* alibi_slopes_batch_stride */ 0, /* q_row_stride */ q_stride[q_rank - 3] as u32, /* k_row_stride */ k_stride[k_rank - 3] as u32, /* v_row_stride */ v_stride[v_rank - 3] as u32, @@ -364,8 +607,10 @@ impl FlashAttnVarLen { /* seqlen_k */ self.max_seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, - /* is_causal */ causal, /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, ) } @@ -440,13 +685,176 @@ pub fn flash_attn_varlen( softmax_scale: f32, causal: bool, ) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + let op = FlashAttnVarLen { softmax_scale, - causal, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. 
+/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +pub fn flash_attn_varlen_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. 
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, }; q.apply_op3(k, v, op) } From 84250bf52f58528cf59dca3b82effd9f07a13cc7 Mon Sep 17 00:00:00 2001 From: optman Date: Sat, 6 Jan 2024 18:43:01 +0800 Subject: [PATCH 3/8] fix index_pos bug when kv cache is disabled. (#1517) * fix index_pos bug when kv cache is disabled * Tweak the fix. --------- Co-authored-by: laurent --- candle-examples/examples/llama/main.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 46f474bb..251c184b 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -165,14 +165,14 @@ fn main() -> Result<()> { let mut index_pos = 0; let mut token_generated = 0; for index in 0..args.sample_len { - let context_size = if cache.use_kv_cache && index > 0 { - 1 + let (context_size, context_index) = if cache.use_kv_cache && index > 0 { + (1, index_pos) } else { - tokens.len() + (tokens.len(), 0) }; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; - let logits = llama.forward(&input, index_pos)?; + let logits = llama.forward(&input, context_index)?; let logits = logits.squeeze(0)?; let logits = if args.repeat_penalty == 1. { logits From b4cb982e498fc121992e7c03d00d04755a66001f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 7 Jan 2024 12:04:14 +0100 Subject: [PATCH 4/8] Simplifying our internal cargo dependencies. 
(#1529) --- Cargo.toml | 8 ++++++++ candle-book/Cargo.toml | 10 +++++----- candle-core/Cargo.toml | 4 ++-- candle-datasets/Cargo.toml | 4 ++-- candle-examples/Cargo.toml | 12 ++++++------ candle-nn/Cargo.toml | 4 ++-- candle-onnx/Cargo.toml | 1 - candle-pyo3/Cargo.toml | 6 +++--- candle-transformers/Cargo.toml | 6 +++--- candle-wasm-examples/bert/Cargo.toml | 6 +++--- candle-wasm-examples/blip/Cargo.toml | 6 +++--- candle-wasm-examples/llama2-c/Cargo.toml | 6 +++--- candle-wasm-examples/phi/Cargo.toml | 6 +++--- candle-wasm-examples/segment-anything/Cargo.toml | 6 +++--- candle-wasm-examples/t5/Cargo.toml | 6 +++--- candle-wasm-examples/whisper/Cargo.toml | 6 +++--- candle-wasm-examples/yolo/Cargo.toml | 4 ++-- candle-wasm-tests/Cargo.toml | 2 +- 18 files changed, 55 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7d61cd74..3d66a02f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,14 @@ license = "MIT OR Apache-2.0" accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" +candle = { path = "./candle-core", package = "candle-core" } +candle-datasets = { path = "./candle-datasets" } +candle-flash-attn = { path = "./candle-flash-attn" } +candle-kernels = { path = "./candle-kernels" } +candle-metal-kernels = { path = "./candle-metal-kernels" } +candle-nn = { path = "./candle-nn" } +candle-onnx = { path = "./candle-onnx" } +candle-transformers = { path = "./candle-transformers" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } cudarc = { version = "0.9.14", features = ["f16"] } diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml index e28e6623..5ccda31e 100644 --- a/candle-book/Cargo.toml +++ b/candle-book/Cargo.toml @@ -11,11 +11,11 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.3.3" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../candle-transformers", version = "0.3.3" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true } +candle = { workspace = true } +candle-datasets = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } +candle-flash-attn = { workspace = true, optional = true } safetensors = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 91655f57..97857a6b 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -12,8 +12,8 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } -candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true } -candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true } +candle-kernels = { workspace = true, optional = true } +candle-metal-kernels = { workspace = true, optional = true } metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml index 69438e0e..ccabf7ed 100644 --- a/candle-datasets/Cargo.toml +++ b/candle-datasets/Cargo.toml @@ -11,8 +11,8 @@ readme = "README.md" [dependencies] byteorder = { workspace = 
true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } hf-hub = { workspace = true} intel-mkl-src = { workspace = true, optional = true } memmap2 = { workspace = true } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 7e081530..439116f8 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -11,12 +11,12 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.3.3" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../candle-transformers", version = "0.3.3" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true } -candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true } +candle = { workspace = true } +candle-datasets = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } +candle-flash-attn = { workspace = true, optional = true } +candle-onnx = { workspace = true, optional = true } csv = "1.3.0" cudarc = { workspace = true, optional = true } diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 5e0e5c2b..214e8a59 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } +candle = { workspace = true } half = { workspace = true } thiserror = { workspace = true } intel-mkl-src = { workspace = true, optional = true } @@ -20,7 +20,7 @@ rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } metal = { workspace = true, optional = true } -candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } +candle-metal-kernels = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index ba33b07a..cf7add01 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -20,4 +20,3 @@ prost-build = "0.12.1" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } clap = { version = "4.2.4", features = ["derive"] } - diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index a03c7559..7c6fbd68 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -15,9 +15,9 @@ crate-type = ["cdylib"] [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } -candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true} +candle = { workspace = true } +candle-nn = { workspace = true } +candle-onnx = { workspace = true, optional = true } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] } diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 83bcff62..1a72c36a 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -12,9 +12,9 @@ readme = "README.md" [dependencies] accelerate-src 
= { workspace = true, optional = true } byteorder = { workspace = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true } -candle-nn = { path = "../candle-nn", version = "0.3.3" } +candle = { workspace = true } +candle-flash-attn = { workspace = true, optional = true } +candle-nn = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } rand = { workspace = true } diff --git a/candle-wasm-examples/bert/Cargo.toml b/candle-wasm-examples/bert/Cargo.toml index 59ce1be3..259a6102 100644 --- a/candle-wasm-examples/bert/Cargo.toml +++ b/candle-wasm-examples/bert/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/blip/Cargo.toml b/candle-wasm-examples/blip/Cargo.toml index 904e90e6..f4de054e 100644 --- a/candle-wasm-examples/blip/Cargo.toml +++ b/candle-wasm-examples/blip/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } num-traits = { workspace = true } diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index 63f8a9c5..ac89a558 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/phi/Cargo.toml b/candle-wasm-examples/phi/Cargo.toml index c4950df9..e437a937 100644 --- a/candle-wasm-examples/phi/Cargo.toml +++ b/candle-wasm-examples/phi/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } num-traits = { workspace = true } diff --git a/candle-wasm-examples/segment-anything/Cargo.toml 
b/candle-wasm-examples/segment-anything/Cargo.toml index 4d886bc2..1840bb62 100644 --- a/candle-wasm-examples/segment-anything/Cargo.toml +++ b/candle-wasm-examples/segment-anything/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } # App crates. diff --git a/candle-wasm-examples/t5/Cargo.toml b/candle-wasm-examples/t5/Cargo.toml index 237f9e61..36cd9386 100644 --- a/candle-wasm-examples/t5/Cargo.toml +++ b/candle-wasm-examples/t5/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml index 5d2b2a38..6c6857e4 100644 --- a/candle-wasm-examples/whisper/Cargo.toml +++ b/candle-wasm-examples/whisper/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml index eb2c320b..0e5a91a8 100644 --- a/candle-wasm-examples/yolo/Cargo.toml +++ b/candle-wasm-examples/yolo/Cargo.toml @@ -9,8 +9,8 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } num-traits = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-wasm-tests/Cargo.toml b/candle-wasm-tests/Cargo.toml index a684f2ce..40c37acd 100644 --- a/candle-wasm-tests/Cargo.toml +++ b/candle-wasm-tests/Cargo.toml @@ -7,7 +7,7 @@ keywords.workspace = true categories.workspace = true [dependencies] -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } +candle = { workspace = true } rand = { workspace = true } getrandom = { version = "0.2", features = ["js"] } From e72d52b1a2118f8773866e87237586bab762a9c6 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 7 Jan 2024 12:26:20 +0100 Subject: [PATCH 5/8] Unpin more of the workspace relative dependencies.
(#1535) --- candle-flash-attn/Cargo.toml | 4 ++-- candle-onnx/Cargo.toml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index 64e690e6..0d3af91d 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" readme = "README.md" [dependencies] -candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" } +candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] @@ -21,4 +21,4 @@ rayon = "1.7.0" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } -candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] } +candle-nn = { path = "../candle-nn", features = ["cuda"] } diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index cf7add01..de1e3350 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -10,8 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } +candle = { path = "../candle-core", package = "candle-core" } +candle-nn = { path = "../candle-nn" } prost = "0.12.1" [build-dependencies] From 30313c308106fff7b20fc8cb2b27eb79800cb818 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 7 Jan 2024 12:29:24 +0100 Subject: [PATCH 6/8] Moving to a proper build crate `bindgen_cuda`. (#1531) * Moving to a proper build crate `bindgen_cuda`. * Fmt. --- candle-flash-attn/Cargo.toml | 4 +- candle-flash-attn/build.rs | 273 +++++------------------------------ candle-kernels/Cargo.toml | 4 +- candle-kernels/build.rs | 243 +------------------------------ 4 files changed, 41 insertions(+), 483 deletions(-) diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index 0d3af91d..d8e8da82 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -15,9 +15,9 @@ candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] +bindgen_cuda = "0.1.1" anyhow = { version = "1", features = ["backtrace"] } -num_cpus = "1.15.0" -rayon = "1.7.0" + [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index fde3aeed..4002770b 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -2,44 +2,32 @@ // The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment // variable in order to cache the compiled artifacts and avoid recompiling too often. 
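As an aside, the caching described in the comment above amounts to preferring a user-supplied, persistent build directory over cargo's throwaway OUT_DIR. A minimal sketch of that selection logic for a build script follows; it is illustrative only, and the real candle-flash-attn build.rs may differ in its details:

    use std::path::PathBuf;

    // Prefer a persistent cache directory when the user provides one, so the CUDA
    // kernels are not recompiled on every build; otherwise fall back to OUT_DIR.
    fn flash_attn_build_dir() -> PathBuf {
        std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR")
            .or_else(|_| std::env::var("OUT_DIR"))
            .map(PathBuf::from)
            .expect("cargo always sets OUT_DIR for build scripts")
    }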
use anyhow::{Context, Result}; -use rayon::prelude::*; use std::path::PathBuf; -use std::str::FromStr; const KERNEL_FILES: [&str; 17] = [ - "flash_api.cu", - "flash_fwd_hdim128_fp16_sm80.cu", - "flash_fwd_hdim160_fp16_sm80.cu", - "flash_fwd_hdim192_fp16_sm80.cu", - "flash_fwd_hdim224_fp16_sm80.cu", - "flash_fwd_hdim256_fp16_sm80.cu", - "flash_fwd_hdim32_fp16_sm80.cu", - "flash_fwd_hdim64_fp16_sm80.cu", - "flash_fwd_hdim96_fp16_sm80.cu", - "flash_fwd_hdim128_bf16_sm80.cu", - "flash_fwd_hdim160_bf16_sm80.cu", - "flash_fwd_hdim192_bf16_sm80.cu", - "flash_fwd_hdim224_bf16_sm80.cu", - "flash_fwd_hdim256_bf16_sm80.cu", - "flash_fwd_hdim32_bf16_sm80.cu", - "flash_fwd_hdim64_bf16_sm80.cu", - "flash_fwd_hdim96_bf16_sm80.cu", + "kernels/flash_api.cu", + "kernels/flash_fwd_hdim128_fp16_sm80.cu", + "kernels/flash_fwd_hdim160_fp16_sm80.cu", + "kernels/flash_fwd_hdim192_fp16_sm80.cu", + "kernels/flash_fwd_hdim224_fp16_sm80.cu", + "kernels/flash_fwd_hdim256_fp16_sm80.cu", + "kernels/flash_fwd_hdim32_fp16_sm80.cu", + "kernels/flash_fwd_hdim64_fp16_sm80.cu", + "kernels/flash_fwd_hdim96_fp16_sm80.cu", + "kernels/flash_fwd_hdim128_bf16_sm80.cu", + "kernels/flash_fwd_hdim160_bf16_sm80.cu", + "kernels/flash_fwd_hdim192_bf16_sm80.cu", + "kernels/flash_fwd_hdim224_bf16_sm80.cu", + "kernels/flash_fwd_hdim256_bf16_sm80.cu", + "kernels/flash_fwd_hdim32_bf16_sm80.cu", + "kernels/flash_fwd_hdim64_bf16_sm80.cu", + "kernels/flash_fwd_hdim96_bf16_sm80.cu", ]; fn main() -> Result<()> { - let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else( - |_| num_cpus::get_physical(), - |s| usize::from_str(&s).unwrap(), - ); - - rayon::ThreadPoolBuilder::new() - .num_threads(num_cpus) - .build_global() - .unwrap(); - println!("cargo:rerun-if-changed=build.rs"); for kernel_file in KERNEL_FILES.iter() { - println!("cargo:rerun-if-changed=kernels/{kernel_file}"); + println!("cargo:rerun-if-changed={kernel_file}"); } println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h"); println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h"); @@ -66,223 +54,30 @@ fn main() -> Result<()> { )) } }; - set_cuda_include_dir()?; - let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); - println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); - - let compute_cap = compute_cap()?; + let kernels = KERNEL_FILES.iter().collect(); + let builder = bindgen_cuda::Builder::default() + .kernel_paths(kernels) + .out_dir(build_dir.clone()) + .arg("-std=c++17") + .arg("-O3") + .arg("-U__CUDA_NO_HALF_OPERATORS__") + .arg("-U__CUDA_NO_HALF_CONVERSIONS__") + .arg("-U__CUDA_NO_HALF2_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + .arg("-Icutlass/include") + .arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--verbose"); let out_file = build_dir.join("libflashattention.a"); + builder.build_lib(out_file); - let kernel_dir = PathBuf::from("kernels"); - let cu_files: Vec<_> = KERNEL_FILES - .iter() - .map(|f| { - let mut obj_file = out_dir.join(f); - obj_file.set_extension("o"); - (kernel_dir.join(f), obj_file) - }) - .collect(); - let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified()); - let should_compile = if out_file.exists() { - kernel_dir - .read_dir() - .expect("kernels folder should exist") - .any(|entry| { - if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) { - let in_modified = entry.metadata().unwrap().modified().unwrap(); - in_modified.duration_since(*out_modified).is_ok() - } else { - true - } - }) - } else { - true - }; - if should_compile 
{ - cu_files - .par_iter() - .map(|(cu_file, obj_file)| { - let mut command = std::process::Command::new("nvcc"); - command - .arg("-std=c++17") - .arg("-O3") - .arg("-U__CUDA_NO_HALF_OPERATORS__") - .arg("-U__CUDA_NO_HALF_CONVERSIONS__") - .arg("-U__CUDA_NO_HALF2_OPERATORS__") - .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") - .arg(format!("--gpu-architecture=sm_{compute_cap}")) - .arg("-c") - .args(["-o", obj_file.to_str().unwrap()]) - .args(["--default-stream", "per-thread"]) - .arg("-Icutlass/include") - .arg("--expt-relaxed-constexpr") - .arg("--expt-extended-lambda") - .arg("--use_fast_math") - .arg("--verbose"); - if let Ok(ccbin_path) = &ccbin_env { - command - .arg("-allow-unsupported-compiler") - .args(["-ccbin", ccbin_path]); - } - command.arg(cu_file); - let output = command - .spawn() - .context("failed spawning nvcc")? - .wait_with_output()?; - if !output.status.success() { - anyhow::bail!( - "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - &command, - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ) - } - Ok(()) - }) - .collect::>()?; - let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::>(); - let mut command = std::process::Command::new("nvcc"); - command - .arg("--lib") - .args(["-o", out_file.to_str().unwrap()]) - .args(obj_files); - let output = command - .spawn() - .context("failed spawning nvcc")? - .wait_with_output()?; - if !output.status.success() { - anyhow::bail!( - "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - &command, - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ) - } - } println!("cargo:rustc-link-search={}", build_dir.display()); println!("cargo:rustc-link-lib=flashattention"); println!("cargo:rustc-link-lib=dylib=cudart"); println!("cargo:rustc-link-lib=dylib=stdc++"); - /* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never - finishing to run for some reason. Calling nvcc manually worked fine. - cc::Build::new() - .cuda(true) - .include("cutlass/include") - .flag("--expt-relaxed-constexpr") - .flag("--default-stream") - .flag("per-thread") - .flag(&format!("--gpu-architecture=sm_{compute_cap}")) - .file("kernels/flash_fwd_hdim32_fp16_sm80.cu") - .compile("flashattn"); - */ Ok(()) } - -fn set_cuda_include_dir() -> Result<()> { - // NOTE: copied from cudarc build.rs. - let env_vars = [ - "CUDA_PATH", - "CUDA_ROOT", - "CUDA_TOOLKIT_ROOT_DIR", - "CUDNN_LIB", - ]; - let env_vars = env_vars - .into_iter() - .map(std::env::var) - .filter_map(Result::ok) - .map(Into::::into); - - let roots = [ - "/usr", - "/usr/local/cuda", - "/opt/cuda", - "/usr/lib/cuda", - "C:/Program Files/NVIDIA GPU Computing Toolkit", - "C:/CUDA", - ]; - let roots = roots.into_iter().map(Into::::into); - let root = env_vars - .chain(roots) - .find(|path| path.join("include").join("cuda.h").is_file()) - .context("cannot find include/cuda.h")?; - println!( - "cargo:rustc-env=CUDA_INCLUDE_DIR={}", - root.join("include").display() - ); - Ok(()) -} - -#[allow(unused)] -fn compute_cap() -> Result { - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - // Try to parse compute caps from env - let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); - compute_cap_str - .parse::() - .context("Could not parse compute cap")? 
- } else { - // Use nvidia-smi to get the current compute cap - let out = std::process::Command::new("nvidia-smi") - .arg("--query-gpu=compute_cap") - .arg("--format=csv") - .output() - .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; - let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; - let mut lines = out.lines(); - assert_eq!( - lines.next().context("missing line in stdout")?, - "compute_cap" - ); - let cap = lines - .next() - .context("missing line in stdout")? - .replace('.', ""); - let cap = cap - .parse::() - .with_context(|| format!("cannot parse as int {cap}"))?; - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); - cap - }; - - // Grab available GPU codes from nvcc and select the highest one - let (supported_nvcc_codes, max_nvcc_code) = { - let out = std::process::Command::new("nvcc") - .arg("--list-gpu-code") - .output() - .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - let out = std::str::from_utf8(&out.stdout).unwrap(); - - let out = out.lines().collect::>(); - let mut codes = Vec::with_capacity(out.len()); - for code in out { - let code = code.split('_').collect::>(); - if !code.is_empty() && code.contains(&"sm") { - if let Ok(num) = code[1].parse::() { - codes.push(num); - } - } - } - codes.sort(); - let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?; - (codes, max_nvcc_code) - }; - - // Check that nvcc supports the asked compute caps - if !supported_nvcc_codes.contains(&compute_cap) { - anyhow::bail!( - "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}." - ); - } - if compute_cap > max_nvcc_code { - anyhow::bail!( - "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}" - ); - } - - Ok(compute_cap) -} diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml index e81fe39c..0cd4a14d 100644 --- a/candle-kernels/Cargo.toml +++ b/candle-kernels/Cargo.toml @@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0" [dependencies] [build-dependencies] -anyhow = { version = "1", features = ["backtrace"] } -glob = "0.3.1" -rayon = "1.7.0" +bindgen_cuda = "0.1.1" diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs index 17a0bf9c..63d744ca 100644 --- a/candle-kernels/build.rs +++ b/candle-kernels/build.rs @@ -1,243 +1,8 @@ -use std::io::Write; - fn main() { println!("cargo:rerun-if-changed=build.rs"); - cuda::set_include_dir(); - let (write, kernel_paths) = cuda::build_ptx(); - if write { - let mut file = std::fs::File::create("src/lib.rs").unwrap(); - for kernel_path in kernel_paths { - let name = kernel_path.file_stem().unwrap().to_str().unwrap(); - file.write_all( - format!( - r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#, - name.to_uppercase().replace('.', "_"), - name - ) - .as_bytes(), - ) - .unwrap(); - file.write_all(&[b'\n']).unwrap(); - } - } -} - -mod cuda { - use anyhow::{Context, Result}; - - pub fn set_include_dir() { - use std::path::PathBuf; - // NOTE: copied from cudarc build.rs. - // We can't actually set a env!() value from another crate, - // so we have to do that here. 
- - // use PathBuf; - - let env_vars = [ - "CUDA_PATH", - "CUDA_ROOT", - "CUDA_TOOLKIT_ROOT_DIR", - "CUDNN_LIB", - ]; - #[allow(unused)] - let env_vars = env_vars - .into_iter() - .map(std::env::var) - .filter_map(Result::ok) - .map(Into::::into); - - let roots = [ - "/usr", - "/usr/local/cuda", - "/opt/cuda", - "/usr/lib/cuda", - "C:/Program Files/NVIDIA GPU Computing Toolkit", - "C:/CUDA", - ]; - #[allow(unused)] - let roots = roots.into_iter().map(Into::::into); - - #[cfg(feature = "ci-check")] - let root: PathBuf = "ci".into(); - - #[cfg(not(feature = "ci-check"))] - let root = env_vars - .chain(roots) - .find(|path| path.join("include").join("cuda.h").is_file()) - .unwrap(); - - println!( - "cargo:rustc-env=CUDA_INCLUDE_DIR={}", - root.join("include").display() - ); - } - - pub fn build_ptx() -> (bool, Vec) { - use rayon::prelude::*; - use std::path::PathBuf; - let out_dir = std::env::var("OUT_DIR").unwrap(); - let kernel_paths: Vec = glob::glob("src/*.cu") - .unwrap() - .map(|p| p.unwrap()) - .collect(); - let mut include_directories: Vec = glob::glob("src/**/*.cuh") - .unwrap() - .map(|p| p.unwrap()) - .collect(); - - println!("cargo:rerun-if-changed=src/"); - // for path in &kernel_paths { - // println!("cargo:rerun-if-changed={}", path.display()); - // } - - for path in &mut include_directories { - // println!("cargo:rerun-if-changed={}", path.display()); - let destination = - std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap()); - std::fs::copy(path.clone(), destination).unwrap(); - // remove the filename from the path so it's just the directory - path.pop(); - } - - include_directories.sort(); - include_directories.dedup(); - - let compute_cap = compute_cap().expect("Could not get Cuda compute cap"); - - #[allow(unused)] - let include_options: Vec = include_directories - .into_iter() - .map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap()) - .collect::>(); - - let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); - println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); - let children = kernel_paths - .par_iter() - .flat_map(|p| { - let mut output = p.clone(); - output.set_extension("ptx"); - let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap()); - - let ignore = if output_filename.exists() { - let out_modified = output_filename.metadata().unwrap().modified().unwrap(); - let in_modified = p.metadata().unwrap().modified().unwrap(); - out_modified.duration_since(in_modified).is_ok() - } else { - false - }; - if ignore { - None - } else { - let mut command = std::process::Command::new("nvcc"); - command.arg(format!("--gpu-architecture=sm_{compute_cap}")) - .arg("--ptx") - .args(["--default-stream", "per-thread"]) - .args(["--output-directory", &out_dir]) - // Flash attention only - // .arg("--expt-relaxed-constexpr") - .args(&include_options); - if let Ok(ccbin_path) = &ccbin_env { - command - .arg("-allow-unsupported-compiler") - .args(["-ccbin", ccbin_path]); - } - command.arg(p); - Some((p, command.spawn() - .expect("nvcc failed to start. 
Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output())) - } - }) - .collect::>(); - - let ptx_paths: Vec = glob::glob(&format!("{out_dir}/**/*.ptx")) - .unwrap() - .map(|p| p.unwrap()) - .collect(); - // We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed - // some old ones - let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len(); - for (kernel_path, child) in children { - let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - assert!( - output.status.success(), - "nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ); - } - (write, kernel_paths) - } - - #[allow(unused)] - fn compute_cap() -> Result { - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - // Try to parse compute caps from env - let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); - compute_cap_str - .parse::() - .context("Could not parse code")? - } else { - // Use nvidia-smi to get the current compute cap - let out = std::process::Command::new("nvidia-smi") - .arg("--query-gpu=compute_cap") - .arg("--format=csv") - .output() - .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; - let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; - let mut lines = out.lines(); - assert_eq!( - lines.next().context("missing line in stdout")?, - "compute_cap" - ); - let cap = lines - .next() - .context("missing line in stdout")? - .replace('.', ""); - let cap = cap - .parse::() - .with_context(|| format!("cannot parse as int {cap}"))?; - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); - cap - }; - - // Grab available GPU codes from nvcc and select the highest one - let (supported_nvcc_codes, max_nvcc_code) = { - let out = std::process::Command::new("nvcc") - .arg("--list-gpu-code") - .output() - .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - let out = std::str::from_utf8(&out.stdout).unwrap(); - - let out = out.lines().collect::>(); - let mut codes = Vec::with_capacity(out.len()); - for code in out { - let code = code.split('_').collect::>(); - if !code.is_empty() && code.contains(&"sm") { - if let Ok(num) = code[1].parse::() { - codes.push(num); - } - } - } - codes.sort(); - let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?; - (codes, max_nvcc_code) - }; - - // Check that nvcc supports the asked compute caps - if !supported_nvcc_codes.contains(&compute_cap) { - anyhow::bail!( - "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}." - ); - } - if compute_cap > max_nvcc_code { - anyhow::bail!( - "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}" - ); - } - - Ok(compute_cap) - } + let builder = bindgen_cuda::Builder::default(); + println!("cargo:info={builder:?}"); + let bindings = builder.build_ptx().unwrap(); + bindings.write("src/lib.rs").unwrap(); } From 89b5a068585b73193d2004a7293d5b2fa6c30bfd Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 7 Jan 2024 17:18:46 +0100 Subject: [PATCH 7/8] Use bindgen-cuda for the custom-kernel example. (#1536) * Use bindgen-cuda for the custom-kernel example. 
* Only depend on the kernels when cuda is enabled. * Skip rustfmt. --- candle-examples/Cargo.toml | 3 +- candle-examples/build.rs | 247 ++---------------- .../examples/custom-ops/cuda_kernels.rs | 3 +- candle-examples/examples/custom-ops/main.rs | 3 +- 4 files changed, 20 insertions(+), 236 deletions(-) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 439116f8..00340d08 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -49,11 +49,12 @@ tokio = "1.29.1" [build-dependencies] anyhow = { workspace = true } +bindgen_cuda = { version = "0.1.1", optional = true } [features] default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] -cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] +cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"] cudnn = ["candle/cudnn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] diff --git a/candle-examples/build.rs b/candle-examples/build.rs index 0af3a6a4..ba40aeb4 100644 --- a/candle-examples/build.rs +++ b/candle-examples/build.rs @@ -4,251 +4,34 @@ use std::io::Write; use std::path::PathBuf; struct KernelDirectories { - kernel_dir: &'static str, + kernel_glob: &'static str, rust_target: &'static str, include_dirs: &'static [&'static str], } -const DIRS: [KernelDirectories; 1] = [KernelDirectories { - kernel_dir: "examples/custom-ops/kernels/", +const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories { + kernel_glob: "examples/custom-ops/kernels/*.cu", rust_target: "examples/custom-ops/cuda_kernels.rs", include_dirs: &[], }]; -impl KernelDirectories { - fn maybe_build_ptx( - &self, - cu_file: &std::path::Path, - ptx_file: &std::path::Path, - compute_cap: usize, - ) -> Result<()> { - let should_compile = if ptx_file.exists() { - let ptx_modified = ptx_file.metadata()?.modified()?; - let cu_modified = cu_file.metadata()?.modified()?; - cu_modified.duration_since(ptx_modified).is_ok() - } else { - true - }; - if should_compile { - #[cfg(feature = "cuda")] - { - let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); - println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); - let mut command = std::process::Command::new("nvcc"); - let out_dir = ptx_file.parent().context("no parent for ptx file")?; - let include_dirs: Vec = - self.include_dirs.iter().map(|c| format!("-I{c}")).collect(); - command - .arg(format!("--gpu-architecture=sm_{compute_cap}")) - .arg("--ptx") - .args(["--default-stream", "per-thread"]) - .args(["--output-directory", out_dir.to_str().unwrap()]) - .arg(format!("-I/{}", self.kernel_dir)) - .args(include_dirs) - .arg(cu_file); - if let Ok(ccbin_path) = &ccbin_env { - command - .arg("-allow-unsupported-compiler") - .args(["-ccbin", ccbin_path]); - } - let output = command - .spawn() - .context("failed spawning nvcc")? 
- .wait_with_output()?; - if !output.status.success() { - anyhow::bail!( - "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ) - } - } - #[cfg(not(feature = "cuda"))] - std::fs::OpenOptions::new() - .create(true) - .write(true) - .open(ptx_file)?; - } - Ok(()) - } - fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> { - println!("cargo:rerun-if-changed={}", self.kernel_dir); - let kernel_dir = PathBuf::from(self.kernel_dir); - let out_dir = out_dir.join(self.kernel_dir); - if !out_dir.exists() { - std::fs::create_dir_all(&out_dir)?; - } - let mut cu_files = vec![]; - let mut cuh_files = vec![]; - for file in std::fs::read_dir(kernel_dir)?.flatten() { - let file = file.path(); - match file.extension().and_then(|v| v.to_str()) { - Some("cu") => cu_files.push(file), - Some("cuh") => cuh_files.push(file), - _ => {} - } - } - - let mut ptx_paths = vec![]; - for cu_file in cu_files.iter() { - let file_stem = cu_file - .file_stem() - .with_context(|| format!("no stem {cu_file:?}"))?; - let file_stem = file_stem.to_string_lossy().into_owned(); - let ptx_file = out_dir.join(&format!("{file_stem}.ptx")); - self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?; - ptx_paths.push(ptx_file); - } - - let regenerate_rs_file = true; - if regenerate_rs_file { - let mut file = std::fs::File::create(self.rust_target)?; - for ptx_path in ptx_paths { - let name = ptx_path - .file_stem() - .context("empty stem")? - .to_string_lossy(); - file.write_all(b"#[rustfmt::skip]\n")?; - let const_definition = format!( - r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#, - name.to_uppercase().replace('.', "_"), - self.kernel_dir, - ); - file.write_all(const_definition.as_bytes())?; - file.write_all(b"\n")?; - } - } - Ok(()) - } -} - fn main() -> Result<()> { println!("cargo:rerun-if-changed=build.rs"); - let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?; - let out_dir = PathBuf::from(out_dir); #[cfg(feature = "cuda")] - set_cuda_include_dir()?; - #[cfg(feature = "cuda")] - let compute_cap = compute_cap()?; + { + for kdir in KERNEL_DIRS.iter() { + let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob); + println!("cargo:info={builder:?}"); + let bindings = builder.build_ptx().unwrap(); + bindings.write(kdir.rust_target).unwrap() + } + } #[cfg(not(feature = "cuda"))] - let compute_cap = 0; - for d in DIRS { - d.process(&out_dir, compute_cap)? + { + for kdir in KERNEL_DIRS.iter() { + let _file = std::fs::File::create(kdir.rust_target)?; + } } Ok(()) } - -fn set_cuda_include_dir() -> Result<()> { - // NOTE: copied from cudarc build.rs. 
- let env_vars = [ - "CUDA_PATH", - "CUDA_ROOT", - "CUDA_TOOLKIT_ROOT_DIR", - "CUDNN_LIB", - ]; - let env_vars = env_vars - .into_iter() - .map(std::env::var) - .filter_map(Result::ok) - .map(Into::::into); - - let roots = [ - "/usr", - "/usr/local/cuda", - "/opt/cuda", - "/usr/lib/cuda", - "C:/Program Files/NVIDIA GPU Computing Toolkit", - "C:/CUDA", - ]; - let roots = roots.into_iter().map(Into::::into); - let root = env_vars - .chain(roots) - .find(|path| path.join("include").join("cuda.h").is_file()) - .context("cannot find include/cuda.h")?; - println!( - "cargo:rustc-env=CUDA_INCLUDE_DIR={}", - root.join("include").display() - ); - Ok(()) -} - -#[allow(unused)] -fn compute_cap() -> Result { - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - // Try to parse compute cap from env - let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); - compute_cap_str - .parse::() - .context("Could not parse code")? - } else { - // Grab compute cap from nvidia-smi - let out = std::process::Command::new("nvidia-smi") - .arg("--query-gpu=compute_cap") - .arg("--format=csv") - .output() - .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; - let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; - let mut lines = out.lines(); - assert_eq!( - lines.next().context("missing line in stdout")?, - "compute_cap" - ); - let cap = lines - .next() - .context("missing line in stdout")? - .replace('.', ""); - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); - cap.parse::() - .with_context(|| format!("cannot parse as int {cap}"))? - }; - - // Grab available GPU codes from nvcc and select the highest one - let max_nvcc_code = { - let out = std::process::Command::new("nvcc") - .arg("--list-gpu-code") - .output() - .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - let out = std::str::from_utf8(&out.stdout).unwrap(); - - let out = out.lines().collect::>(); - let mut codes = Vec::with_capacity(out.len()); - for code in out { - let code = code.split('_').collect::>(); - if !code.is_empty() && code.contains(&"sm") { - if let Ok(num) = code[1].parse::() { - codes.push(num); - } - } - } - codes.sort(); - if !codes.contains(&compute_cap) { - anyhow::bail!( - "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}." - ); - } - *codes.last().unwrap() - }; - - // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc, - // then choose the highest gpu code in nvcc - if compute_cap > max_nvcc_code { - println!( - "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}." 
- ); - compute_cap = max_nvcc_code; - } - - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - compute_cap = compute_cap_str - .parse::() - .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?; - println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP"); - } - println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}"); - Ok(compute_cap) -} diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs index 0bee73aa..c00b601b 100644 --- a/candle-examples/examples/custom-ops/cuda_kernels.rs +++ b/candle-examples/examples/custom-ops/cuda_kernels.rs @@ -1,2 +1 @@ -#[rustfmt::skip] -pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx")); +pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx")); diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs index f2f534dc..30e413c1 100644 --- a/candle-examples/examples/custom-ops/main.rs +++ b/candle-examples/examples/custom-ops/main.rs @@ -6,7 +6,8 @@ #[cfg(feature = "mkl")] extern crate intel_mkl_src; -#[allow(unused)] +#[rustfmt::skip] +#[cfg(feature = "cuda")] mod cuda_kernels; use clap::Parser; From 0eb90ed7831d451e2e420ecd158151b44dc5b2ba Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 7 Jan 2024 20:21:49 +0100 Subject: [PATCH 8/8] Simpler repro for the neon optimization issue + bugfix (#1544) * Simpler repro for the neon optimization issue. * Bugfix for q4k. * Improve the fix, share the dot-prod bit. * Clippy fixes. * Fix for q6k. * Also fix for q2k. * Use the new shared dotprod. * Add more testing. --- candle-core/src/quantized/neon.rs | 208 ++++++++------------------- candle-core/tests/quantized_tests.rs | 57 +++++--- 2 files changed, 97 insertions(+), 168 deletions(-) diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 3cb56229..c4d5d6f4 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -12,6 +12,14 @@ use core::arch::arm::*; #[cfg(target_arch = "aarch64")] use core::arch::aarch64::*; +#[inline(always)] +unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t { + // TODO: dotprod + let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); + let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)) +} + #[inline(always)] pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { let qk = QK8_0; @@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); - // TODO: Support dotprod when it's available outside of nightly. 
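The shared vdotq_s32 helper introduced above performs the multiply-widen-accumulate sequence in 32-bit lanes. That is the heart of the bugfix: the k-quant paths below previously reduced with vaddq_s16/vaddvq_s16, i.e. they summed sixteen i8*i8 products entirely in 16-bit arithmetic, which can wrap (this appears to be the overflow behind #1526 that the new test further down targets). A small standalone sketch of the arithmetic, with illustrative worst-case numbers rather than values taken from the issue:

    // The removed reduction pattern was roughly:
    //   vaddvq_s16(vaddq_s16(vmull_s8(a_lo, b_lo), vmull_s8(a_hi, b_hi)))
    // i.e. 16 products of i8 values summed in i16. Each product can reach
    // 127 * 127 = 16129, so the horizontal sum does not fit in an i16.
    fn main() {
        let worst_case: i32 = 16 * 127 * 127;
        assert_eq!(worst_case, 258_064);
        assert!(worst_case > i32::from(i16::MAX)); // i16::MAX is 32_767: a 16-bit sum wraps
        assert!(worst_case < i32::MAX); // the widened vdotq_s32 path has ample headroom
    }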
- let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l)); - let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); - let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h)); - let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - - let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - + let pl0 = vdotq_s32(v0_0ls, v1_0l); + let ph0 = vdotq_s32(v0_0hs, v1_0h); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), @@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); - // TODO dotprod once this is the intrinsics are. - let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0)); - let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); - let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1)); - let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - - let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); - let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + let p0 = vdotq_s32(x0_0, y0_0); + let p1 = vdotq_s32(x0_1, y0_1); sumv0 = vmlaq_n_f32( sumv0, @@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res for i in (0..QK_K).step_by(16) { let xs = vld1q_s8(xs.add(i)); let ys = vld1q_s8(ys.add(i)); - let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys)); - let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys)); - - let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up)); + let xy = vdotq_s32(xs, ys); sum_i = vaddq_s32(sum_i, xy) } sumf += vaddvq_s32(sum_i) as f32 * scale @@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3)); - // TODO: dotprod - - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)), - ); + let p0 = vdotq_s32(q6bytes_0, q8bytes.0); + let p1 = vdotq_s32(q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1; + isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)), - vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)), - vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)), - ); + let p2 = vdotq_s32(q6bytes_2, q8bytes.2); + let p3 = vdotq_s32(q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1; + isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); let q8bytes = vld1q_s8_x4(q8); @@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3)); - // TODO: dotprod case. 
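Whatever the lane layout, each rewritten reduction must still produce the same integer dot product as before, only accumulated in a wider type. A scalar model of what vaddvq_s32(vdotq_s32(a, b)) is expected to equal, shown as an illustrative sketch rather than code from the patch (the tests further down use a vec_dot_reference helper in the same spirit):

    // Scalar model: the sum of 16 pairwise i8 products, accumulated in i32.
    fn dot_i8(a: &[i8; 16], b: &[i8; 16]) -> i32 {
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| i32::from(x) * i32::from(y))
            .sum()
    }

    fn main() {
        let a = [1i8; 16];
        let b = [2i8; 16];
        assert_eq!(dot_i8(&a, &b), 32); // 16 lanes of 1 * 2
    }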
- let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)), - ); + let p0 = vdotq_s32(q6bytes_0, q8bytes.0); + let p1 = vdotq_s32(q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1; + isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)), - vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)), - vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)), - ); + let p2 = vdotq_s32(q6bytes_2, q8bytes.2); + let p3 = vdotq_s32(q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1; + isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); } sum += d_all * y.d * ((isum - 32 * isum_mins) as f32); @@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2)); let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3)); - // TODO: dotprod - - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)), - ); - sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32; + let p0 = vdotq_s32(q5bytes_0, q8bytes.0); + let p1 = vdotq_s32(q5bytes_1, q8bytes.1); + sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32; scales = scales.add(1); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)), - vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)), - vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)), - ); - sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32; + let p2 = vdotq_s32(q5bytes_2, q8bytes.2); + let p3 = vdotq_s32(q5bytes_3, q8bytes.3); + sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32; scales = scales.add(1); } sumf += d * sumi as f32 - dmin * sumi_mins as f32; @@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res for j in 0..QK_K / 64 { let q4bits = vld1q_u8_x2(q4); q4 = q4.add(32); - // TODO: dotprod let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); let q4bytes = int8x16x2_t( vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)), vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)), ); - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)), - ); - sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32; + let p0 = vdotq_s32(q4bytes.0, q8bytes.0); + let p1 = vdotq_s32(q4bytes.1, q8bytes.1); + sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as 
i32; let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); @@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)), vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)), ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)), - ); - sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32; + let p2 = vdotq_s32(q4bytes.0, q8bytes.0); + let p3 = vdotq_s32(q4bytes.1, q8bytes.1); + sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32; } sumf += d * (sumi1 + sumi2) as f32; } @@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - // TODO: dotprod - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)), - vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)), - vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)), - ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)), - vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)), - vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)), - ); - isum += vaddvq_s16(p0) as i32 * *scale as i32 - + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 - + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 - + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0); + let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1); + let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2); + let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3); + isum += vaddvq_s32(p0) * *scale as i32 + + vaddvq_s32(p1) * *scale.add(1) as i32 + + vaddvq_s32(p2) * *scale.add(2) as i32 + + vaddvq_s32(p3) * *scale.add(3) as i32; scale = scale.add(4); let q3h_0 = vbicq_u8(m2, qhbits.0); @@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - // TODO: dotprod - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)), - vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)), - vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)), - ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)), - vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)), - vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)), - ); - isum += vaddvq_s16(p0) as i32 * *scale as i32 - + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 - + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 - + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0); + let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1); + let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2); + let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3); + isum += vaddvq_s32(p0) * *scale as i32 + + vaddvq_s32(p1) * *scale.add(1) as i32 + + vaddvq_s32(p2) * *scale.add(2) as i32 + + vaddvq_s32(p3) * *scale.add(3) as i32; scale = scale.add(4); if j == 0 { @@ -649,7 
+561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res let mut is = 0usize; // TODO: dotprod - for _j in 0..QK_K / 128 { let q2bits = vld1q_u8_x2(q2); q2 = q2.add(32); @@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale( q2bytes: int8x16x2_t, q8bytes: int8x16x2_t, ) -> i32 { - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)), - ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)), - ); - vaddvq_s16(p1) as i32 * aux[is + index] as i32 - + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32 + let p1 = vdotq_s32(q2bytes.0, q8bytes.0); + let p2 = vdotq_s32(q2bytes.1, q8bytes.1); + vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32 } diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 716cca8d..e7a2ea7f 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1,4 +1,5 @@ use candle_core::{ + bail, quantized::{self, GgmlDType}, test_utils::to_vec2_round, Device, Module, Result, Tensor, @@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) { } } -/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 +/// Creates a vector similar to the ones used in GGML unit tests: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 fn create_ggml_like_vector(offset: f32) -> Vec { (0..GGML_TEST_SIZE) .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos()) @@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 { sum / a.len() as f32 } -/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 +/// Similar to the GGML quantization unit test: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 fn ggml_quantization_error_test(max_error: f32) -> Result<()> { let src = create_ggml_like_vector(0.0); let mut dst = vec![0.0; GGML_TEST_SIZE]; let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; let error = calculate_rmse(src.as_slice(), dst.as_slice()); if error > max_error { - candle_core::bail!( + bail!( "Quantization error {} exceeds max error {}", error, max_error @@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { GgmlDType::Q5K => 0.000740, GgmlDType::Q6K => 0.000952, GgmlDType::Q4_0 => 0.001143, - GgmlDType::Q4_1 => 0.007784, + GgmlDType::Q4_1 => 0.008, GgmlDType::Q5_0 => 0.001353, - GgmlDType::Q5_1 => 0.001363, + GgmlDType::Q5_1 => 0.00149, GgmlDType::Q8_0 => 0.000092, // Not from the ggml repo. 
GgmlDType::Q8K => 0.00065, - _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",), + _ => bail!("No GGML results for quantization type {dtype:?}",), }; Ok(err) } -/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 +/// Similar to the GGML matmul unit test: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 fn ggml_matmul_error_test() -> Result<()> { let a = create_ggml_like_vector(0.0); let b = create_ggml_like_vector(1.0); + ggml_matmul_error_test_::(a.as_slice(), b.as_slice(), 1.0)?; + // Another example that is more likely to trigger the overflow reported in #1526 + let a = (0..GGML_TEST_SIZE) + .map(|i| i as f32 / GGML_TEST_SIZE as f32) + .collect::>(); + let b = (0..GGML_TEST_SIZE) + .map(|i| i as f32 / GGML_TEST_SIZE as f32) + .collect::>(); + ggml_matmul_error_test_::(a.as_slice(), b.as_slice(), 2.0)?; + Ok(()) +} + +fn ggml_matmul_error_test_(a: &[f32], b: &[f32], err_m: f32) -> Result<()> { let length = a.len(); let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE]; let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE]; - T::from_float(&a, &mut a_quant)?; - T::VecDotType::from_float(&b, &mut b_quant)?; + T::from_float(a, &mut a_quant)?; + T::VecDotType::from_float(b, &mut b_quant)?; let result = T::vec_dot(length, &a_quant, &b_quant)?; let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?; - let reference_result = vec_dot_reference(&a, &b); + let reference_result = vec_dot_reference(a, b); if (result - result_unopt).abs() / length as f32 > 1e-6 { - candle_core::bail!( + bail!( "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}" ) } let error = (result - reference_result).abs() / length as f32; - let ggml_error = ggml_reference_matmul_error(T::DTYPE)?; + let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m; if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR { - candle_core::bail!( - "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}", - ); + bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",); } // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML // => we use a slightly higher error threshold const ERROR_LENIENCY: f32 = 0.00001; if error - ERROR_LENIENCY > ggml_error { - candle_core::bail!( + bail!( "Dot product error {} exceeds ggml reference error {}", error, ggml_error @@ -543,6 +558,16 @@ fn ggml_matmul_error_test() -> Result<()> { Ok(()) } +#[test] +fn quantized_mm() -> Result<()> { + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + Ok(()) +} + /// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result. fn get_random_tensors( m: usize,