Mirror of https://github.com/huggingface/candle.git
Flash attention without padding (varlen). (#281)
* Expose the seqlen variable for flash-attn without padding.
* Fix the batched call.
* Adapt for the varlen variant.
* No need to set the batch strides when in varlen mode.
* Add a test (disabled at the moment).
* Get the test to work properly.
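For background (not part of the diff): in the varlen layout used by flash-attn, the sequences of a batch are concatenated back to back along the token dimension instead of being padded to a common length, and a cumulative-sequence-lengths ("cu_seqlens") array records where each sequence starts. A minimal C++ sketch of how such an array is built; the helper name build_cu_seqlens is hypothetical:

#include <cstdint>
#include <vector>

// Build the cumulative sequence-lengths array for a packed (varlen) batch.
// For lengths {3, 5, 2} this yields {0, 3, 8, 10}: sequence b occupies the
// token rows [cu_seqlens[b], cu_seqlens[b + 1]) of the packed tensor.
std::vector<int32_t> build_cu_seqlens(const std::vector<int32_t> &seqlens) {
    std::vector<int32_t> cu_seqlens(seqlens.size() + 1, 0);
    for (size_t b = 0; b < seqlens.size(); ++b) {
        cu_seqlens[b + 1] = cu_seqlens[b] + seqlens[b];
    }
    return cu_seqlens;
}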
@@ -22,6 +22,9 @@ extern "C" void run_mha(
     void *o_ptr,
     void *softmax_lse_ptr,
 
+    int32_t *cu_seqlens_q_ptr,
+    int32_t *cu_seqlens_k_ptr,
+
     uint32_t q_batch_stride,
     uint32_t k_batch_stride,
     uint32_t v_batch_stride,
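In flash-attn's varlen interface these two new arguments are device pointers to int32 arrays with batch_size + 1 entries. A hedged caller-side sketch of how they might be staged on the GPU, reusing the hypothetical build_cu_seqlens helper above (standard CUDA runtime API; error handling elided):

#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

// Hypothetical caller-side helper: copy the cumulative lengths to the device
// so the kernel can locate each packed sequence. Error handling elided.
int32_t *upload_cu_seqlens(const std::vector<int32_t> &seqlens) {
    std::vector<int32_t> host = build_cu_seqlens(seqlens); // batch_size + 1 entries
    int32_t *dev = nullptr;
    cudaMalloc(reinterpret_cast<void **>(&dev), host.size() * sizeof(int32_t));
    cudaMemcpy(dev, host.data(), host.size() * sizeof(int32_t),
               cudaMemcpyHostToDevice);
    return dev;
}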
@@ -100,9 +103,9 @@ extern "C" void run_mha(
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     params.is_bf16 = 0;
-    params.cu_seqlens_q = nullptr;
-    params.cu_seqlens_k = nullptr;
-    params.p_ptr = nullptr;
+    params.cu_seqlens_q = cu_seqlens_q_ptr;
+    params.cu_seqlens_k = cu_seqlens_k_ptr;
+    params.p_ptr = nullptr; // used for `return_softmax`.
 
     cudaStream_t stream = 0; // Use the default stream.
    run_mha_fwd(params, stream);
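This is also why the batch strides need not be set in varlen mode: with sequences packed back to back, the start of each sequence is recovered from cu_seqlens rather than from a fixed per-batch stride. A conceptual sketch of the idea, not the kernel's actual indexing code:

#include <cstdint>

// Conceptual only: the query rows of sequence `b` start at row
// cu_seqlens_q[b] of the packed tensor, so only the per-row stride is
// needed; no per-batch stride is involved. float is used for simplicity.
const float *q_for_seq(const float *q_ptr, const int32_t *cu_seqlens_q,
                       uint32_t q_row_stride, int b) {
    return q_ptr + static_cast<size_t>(cu_seqlens_q[b]) * q_row_stride;
}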