Flash attention without padding (varlen). (#281)

* Expose the seqlen variable for flash-attn without padding (see the layout sketch after this list).

* Fix the batched call.

* Adapt for the varlen variant.

* No need to set the batch strides when in varlen mode.

* Add a test (disabled at the moment).

* Get the test to work properly.
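
For readers unfamiliar with the unpadded ("varlen") convention, here is a minimal sketch of the cumulative-sequence-length layout the kernel works with; the lengths below are made up for illustration and are not taken from the test added in this commit.

```cpp
// Instead of padding every sequence to the longest one, the varlen path packs
// all sequences back to back along the token dimension. Three sequences of
// lengths 3, 5 and 2 become a single buffer of 10 tokens:
//   tokens: [a a a | b b b b b | c c]
// cu_seqlens stores the prefix sums of the lengths, so sequence i occupies the
// token range [cu_seqlens[i], cu_seqlens[i + 1]).
int32_t seqlens[3]    = {3, 5, 2};
int32_t cu_seqlens[4] = {0, 3, 8, 10};
```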
Author: Laurent Mazare
Date: 2023-07-31 09:45:39 +01:00
Committed by: GitHub
Commit: 0ace420e66 (parent a8d8f9f206)
4 changed files with 283 additions and 4 deletions


@@ -22,6 +22,9 @@ extern "C" void run_mha(
     void *o_ptr,
     void *softmax_lse_ptr,
+    int32_t *cu_seqlens_q_ptr,
+    int32_t *cu_seqlens_k_ptr,
     uint32_t q_batch_stride,
     uint32_t k_batch_stride,
     uint32_t v_batch_stride,
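
As a rough sketch of what a caller has to prepare for the two new parameters (the helper name and the lack of error handling are illustrative, not part of this commit): the host computes prefix sums over the per-sequence lengths and copies them to device memory.

```cpp
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>

// Build [0, l0, l0+l1, ...] on the host and copy it to device memory so it can
// be passed as cu_seqlens_q_ptr / cu_seqlens_k_ptr. Error checking omitted.
int32_t *upload_cu_seqlens(const std::vector<int32_t> &seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  for (size_t i = 0; i < seqlens.size(); ++i) {
    cu[i + 1] = cu[i] + seqlens[i];
  }
  int32_t *dev = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&dev), cu.size() * sizeof(int32_t));
  cudaMemcpy(dev, cu.data(), cu.size() * sizeof(int32_t), cudaMemcpyHostToDevice);
  return dev;
}
```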
@@ -100,9 +103,9 @@ extern "C" void run_mha(
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     params.is_bf16 = 0;
-    params.cu_seqlens_q = nullptr;
-    params.cu_seqlens_k = nullptr;
-    params.p_ptr = nullptr;
+    params.cu_seqlens_q = cu_seqlens_q_ptr;
+    params.cu_seqlens_k = cu_seqlens_k_ptr;
+    params.p_ptr = nullptr; // used for `return_softmax`.
     cudaStream_t stream = 0; // Use the default stream.
     run_mha_fwd(params, stream);
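
To see why the batch strides become unnecessary in varlen mode, compare how a query row would be located in the two layouts; this is a sketch with assumed stride names (q_row_stride, q_head_stride), not the kernel's actual indexing code.

```cpp
#include <cstddef>
#include <cstdint>

// Padded layout: q is [batch, seqlen, num_heads, head_dim], so reaching token t
// of sequence b requires an explicit batch stride.
size_t padded_offset(size_t b, size_t t, size_t h, size_t q_batch_stride,
                     size_t q_row_stride, size_t q_head_stride) {
  return b * q_batch_stride + t * q_row_stride + h * q_head_stride;
}

// Varlen layout: q is [total_tokens, num_heads, head_dim]; the start of each
// sequence comes from cu_seqlens, so no batch stride is needed.
size_t varlen_offset(const int32_t *cu_seqlens, size_t b, size_t t, size_t h,
                     size_t q_row_stride, size_t q_head_stride) {
  return (static_cast<size_t>(cu_seqlens[b]) + t) * q_row_stride + h * q_head_stride;
}
```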