chore: update flash attention kernels (#1518)

* chore: update flash attention kernels

* fmt

* remove unused kernels

* force f32

* correct stride
This commit is contained in:
OlivierDehaene
2024-01-05 18:28:55 +01:00
committed by GitHub
parent 3a7304cb0d
commit 8d1a57c9a0
28 changed files with 1087 additions and 466 deletions

View File

@ -1,17 +1,15 @@
#include "flash_fwd_launch_template.h"
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// FWD_HEADDIM_SWITCH(params.d, [&] {
// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
// });
// }
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
FP16_SWITCH(!params.is_bf16, [&] {
FWD_HEADDIM_SWITCH(params.d, [&] {
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
});
});
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
FP16_SWITCH(!params.is_bf16, [&] {
FWD_HEADDIM_SWITCH(params.d, [&] {
// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
// } else {
// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
// }
});
});
}
extern "C" void run_mha(
@ -20,6 +18,7 @@ extern "C" void run_mha(
void *v_ptr,
void *o_ptr,
void *softmax_lse_ptr,
void *alibi_slopes_ptr,
int32_t *cu_seqlens_q_ptr,
int32_t *cu_seqlens_k_ptr,
@ -28,6 +27,7 @@ extern "C" void run_mha(
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t o_batch_stride,
uint32_t alibi_slopes_batch_stride,
uint32_t q_row_stride,
uint32_t k_row_stride,
@ -51,8 +51,11 @@ extern "C" void run_mha(
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
int is_bf16,
int is_causal,
int is_bf16
int window_size_left,
int window_size_right
) {
Flash_fwd_params params;
// Reset the parameters
@ -65,12 +68,14 @@ extern "C" void run_mha(
params.o_ptr = o_ptr;
params.softmax_lse_ptr = softmax_lse_ptr;
params.alibi_slopes_ptr = alibi_slopes_ptr;
// All stride are in elements, not bytes.
params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride;
params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;
params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
@ -92,7 +97,6 @@ extern "C" void run_mha(
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
params.is_causal = is_causal;
// Set the different scale values.
params.scale_softmax = softmax_scale;
@ -106,6 +110,14 @@ extern "C" void run_mha(
params.cu_seqlens_q = cu_seqlens_q_ptr;
params.cu_seqlens_k = cu_seqlens_k_ptr;
params.p_ptr = nullptr; // used for `return_softmax`.
params.seqused_k = nullptr;
params.is_causal = is_causal;
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;
params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);