Files
candle/candle-flash-attn/kernels/flash_api.cu
Laurent Mazare 0ace420e66 Flash attention without padding (varlen). (#281)
* Expose the seqlen variable for flash-attn without padding.

* Fix the batched call.

* Adapt for the varlen variant.

* No need to set the batch strides when in varlen mode.

* Add a test (disabled at the moment).

* Get the test to work properly.
2023-07-31 09:45:39 +01:00

113 lines
3.1 KiB
Plaintext

#include "flash_fwd_launch_template.h"
// TODO: Switch back to handling bf16.
// Dispatches the forward attention launch on the head dimension stored in
// `params.d`. The element type is pinned to half precision here; `params.is_bf16`
// is not consulted by this function.
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    // Single element type handled by this entry point.
    using fp16_t = cutlass::half_t;
    // FWD_HEADDIM_SWITCH supplies `kHeadDim` to the callback — presumably as a
    // compile-time constant chosen from `params.d`; see the macro definition.
    FWD_HEADDIM_SWITCH(params.d, [&] {
        run_mha_fwd_<fp16_t, kHeadDim>(params, stream);
    });
}
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// FP16_SWITCH(!params.is_bf16, [&] {
// FWD_HEADDIM_SWITCH(params.d, [&] {
// run_mha_fwd_<elem_type, kHeadDim>(params, stream);
// });
// });
// }
// C-ABI entry point: packs the raw pointers, strides and sizes it receives into
// a zero-initialised `Flash_fwd_params` and hands it to `run_mha_fwd`.
//
// Pointers:
//   q_ptr / k_ptr / v_ptr / o_ptr  - query, key, value and output buffers.
//   softmax_lse_ptr                - buffer stored as `params.softmax_lse_ptr`.
//   cu_seqlens_q_ptr / cu_seqlens_k_ptr
//                                  - stored as `params.cu_seqlens_q/k`
//                                    (presumably cumulative sequence lengths for
//                                    the varlen path, null otherwise — verify
//                                    against the caller).
// Strides (`*_batch_stride`, `*_row_stride`, `*_head_stride`) are expressed in
// elements, not bytes.
// Sizes: `b`, `h`, `h_k`, `d`, `d_rounded`, `seqlen_*` and `seqlen_*_rounded`
// are copied through unchanged; the rounded values must be provided by the
// caller, nothing is rounded here.
// `softmax_scale` is the softmax scaling factor; `is_causal` is forwarded as is.
//
// The keep probability is fixed to 1, bf16 is turned off, `p_ptr` is null and
// the work is issued on stream 0.
extern "C" void run_mha(
    void *q_ptr,
    void *k_ptr,
    void *v_ptr,
    void *o_ptr,
    void *softmax_lse_ptr,
    int32_t *cu_seqlens_q_ptr,
    int32_t *cu_seqlens_k_ptr,
    uint32_t q_batch_stride,
    uint32_t k_batch_stride,
    uint32_t v_batch_stride,
    uint32_t o_batch_stride,
    uint32_t q_row_stride,
    uint32_t k_row_stride,
    uint32_t v_row_stride,
    uint32_t o_row_stride,
    uint32_t q_head_stride,
    uint32_t k_head_stride,
    uint32_t v_head_stride,
    uint32_t o_head_stride,
    uint32_t b,
    uint32_t h,
    uint32_t h_k,
    uint32_t d,
    uint32_t d_rounded,
    float softmax_scale,
    uint32_t seqlen_q,
    uint32_t seqlen_k,
    uint32_t seqlen_q_rounded,
    uint32_t seqlen_k_rounded,
    int is_causal
) {
    // Every field not assigned below stays zero.
    Flash_fwd_params params;
    memset(&params, 0, sizeof(params));

    // Query tensor: base pointer and strides (in elements).
    params.q_ptr = q_ptr;
    params.q_batch_stride = q_batch_stride;
    params.q_row_stride = q_row_stride;
    params.q_head_stride = q_head_stride;

    // Key tensor.
    params.k_ptr = k_ptr;
    params.k_batch_stride = k_batch_stride;
    params.k_row_stride = k_row_stride;
    params.k_head_stride = k_head_stride;

    // Value tensor.
    params.v_ptr = v_ptr;
    params.v_batch_stride = v_batch_stride;
    params.v_row_stride = v_row_stride;
    params.v_head_stride = v_head_stride;

    // Output tensor.
    params.o_ptr = o_ptr;
    params.o_batch_stride = o_batch_stride;
    params.o_row_stride = o_row_stride;
    params.o_head_stride = o_head_stride;

    // Auxiliary buffers.
    params.softmax_lse_ptr = softmax_lse_ptr;
    params.p_ptr = nullptr; // used for `return_softmax`.
    params.cu_seqlens_q = cu_seqlens_q_ptr;
    params.cu_seqlens_k = cu_seqlens_k_ptr;

    // Problem dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    // NOTE(review): integer division; `h_k == 0` is not guarded against here.
    params.h_h_k_ratio = h / h_k;
    params.d = d;
    params.d_rounded = d_rounded;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;

    // Mode flags.
    params.is_causal = is_causal;
    params.is_bf16 = false;

    // Softmax scaling, plus the same factor multiplied by log2(e).
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;

    // Dropout bookkeeping with a keep probability of 1.
    const float keep_prob = 1.f; // probability to keep
    const float inv_keep_prob = 1.f / keep_prob;
    params.p_dropout = keep_prob;
    params.p_dropout_in_uint8_t = uint8_t(std::floor(keep_prob * 255.0));
    params.rp_dropout = inv_keep_prob;
    params.scale_softmax_rp_dropout = inv_keep_prob * softmax_scale;

    // NOTE(review): always launches on the default stream (0); the C signature
    // offers no way to pass a different one.
    cudaStream_t default_stream = 0;
    run_mha_fwd(params, default_stream);
}