Proper flash-attn parameters. (#244)

* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Setup the o_ flash attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
This commit is contained in:
Laurent Mazare
2023-07-26 10:13:40 +01:00
committed by GitHub
parent e40b150bbe
commit fa2b64d678
5 changed files with 147 additions and 12 deletions

View File

@ -28,16 +28,22 @@ extern "C" void run_mha(
void *k_ptr,
void *v_ptr,
void *o_ptr,
void *softmax_lse_ptr,
uint32_t q_batch_stride,
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t o_batch_stride,
uint32_t q_row_stride,
uint32_t k_row_stride,
uint32_t v_row_stride,
uint32_t o_row_stride,
uint32_t q_head_stride,
uint32_t k_head_stride,
uint32_t v_head_stride,
uint32_t o_head_stride,
uint32_t b,
uint32_t h,
@ -61,14 +67,24 @@ extern "C" void run_mha(
params.q_ptr = q_ptr;
params.k_ptr = k_ptr;
params.v_ptr = v_ptr;
params.o_ptr = o_ptr;
params.softmax_lse_ptr = softmax_lse_ptr;
// All stride are in elements, not bytes.
params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride;
params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
params.v_row_stride = v_row_stride;
params.o_row_stride = o_row_stride;
params.q_head_stride = q_head_stride;
params.k_head_stride = k_head_stride;
params.v_head_stride = v_head_stride;
params.o_ptr = o_ptr;
params.o_head_stride = o_head_stride;
// Set the dimensions.
params.b = b;
@ -87,6 +103,11 @@ extern "C" void run_mha(
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd_<cutlass::half_t, 32>(params, stream);
}