mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
chore: update flash attention kernels (#1518)
* chore: update flash attention kernels * fmt * remove unused kernels * force f32 * correct stride
This commit is contained in:
@ -1,17 +1,15 @@
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||
// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
|
||||
// });
|
||||
// }
|
||||
|
||||
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
||||
});
|
||||
});
|
||||
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||
// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
|
||||
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
||||
// } else {
|
||||
// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
|
||||
// }
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
extern "C" void run_mha(
|
||||
@ -20,6 +18,7 @@ extern "C" void run_mha(
|
||||
void *v_ptr,
|
||||
void *o_ptr,
|
||||
void *softmax_lse_ptr,
|
||||
void *alibi_slopes_ptr,
|
||||
|
||||
int32_t *cu_seqlens_q_ptr,
|
||||
int32_t *cu_seqlens_k_ptr,
|
||||
@ -28,6 +27,7 @@ extern "C" void run_mha(
|
||||
uint32_t k_batch_stride,
|
||||
uint32_t v_batch_stride,
|
||||
uint32_t o_batch_stride,
|
||||
uint32_t alibi_slopes_batch_stride,
|
||||
|
||||
uint32_t q_row_stride,
|
||||
uint32_t k_row_stride,
|
||||
@ -51,8 +51,11 @@ extern "C" void run_mha(
|
||||
uint32_t seqlen_q_rounded,
|
||||
uint32_t seqlen_k_rounded,
|
||||
|
||||
int is_bf16,
|
||||
int is_causal,
|
||||
int is_bf16
|
||||
|
||||
int window_size_left,
|
||||
int window_size_right
|
||||
) {
|
||||
Flash_fwd_params params;
|
||||
// Reset the parameters
|
||||
@ -65,12 +68,14 @@ extern "C" void run_mha(
|
||||
params.o_ptr = o_ptr;
|
||||
|
||||
params.softmax_lse_ptr = softmax_lse_ptr;
|
||||
params.alibi_slopes_ptr = alibi_slopes_ptr;
|
||||
|
||||
// All stride are in elements, not bytes.
|
||||
params.q_batch_stride = q_batch_stride;
|
||||
params.k_batch_stride = k_batch_stride;
|
||||
params.v_batch_stride = v_batch_stride;
|
||||
params.o_batch_stride = o_batch_stride;
|
||||
params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;
|
||||
|
||||
params.q_row_stride = q_row_stride;
|
||||
params.k_row_stride = k_row_stride;
|
||||
@ -92,7 +97,6 @@ extern "C" void run_mha(
|
||||
params.seqlen_k_rounded = seqlen_k_rounded;
|
||||
params.d = d;
|
||||
params.d_rounded = d_rounded;
|
||||
params.is_causal = is_causal;
|
||||
|
||||
// Set the different scale values.
|
||||
params.scale_softmax = softmax_scale;
|
||||
@ -106,6 +110,14 @@ extern "C" void run_mha(
|
||||
params.cu_seqlens_q = cu_seqlens_q_ptr;
|
||||
params.cu_seqlens_k = cu_seqlens_k_ptr;
|
||||
params.p_ptr = nullptr; // used for `return_softmax`.
|
||||
params.seqused_k = nullptr;
|
||||
|
||||
params.is_causal = is_causal;
|
||||
params.window_size_left = window_size_left;
|
||||
params.window_size_right = window_size_right;
|
||||
|
||||
params.is_seqlens_k_cumulative = true;
|
||||
params.num_splits = 1;
|
||||
|
||||
cudaStream_t stream = 0; // Use the default stream.
|
||||
run_mha_fwd(params, stream);
|
||||
|
Reference in New Issue
Block a user