Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)

* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and api.

* make softcap test case better

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Michael Feil
2024-12-31 09:41:23 +01:00
committed by GitHub
parent 71cd6d5533
commit a594ef669c
4 changed files with 182 additions and 3 deletions

View File

@ -55,7 +55,9 @@ extern "C" void run_mha(
int is_causal,
int window_size_left,
int window_size_right
int window_size_right,
float softcap
) {
Flash_fwd_params params;
// Reset the parameters
@ -99,8 +101,16 @@ extern "C" void run_mha(
params.d_rounded = d_rounded;
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
if (softcap > 0.0) {
params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E;
} else{
// Remove potential NaN
params.softcap = 0.0;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));