mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
* update flash-attn v1 * restore: hdim224 * add 224 flash_fwd_template * remove whitespace * softcap is working, including test and api. * make softcap test case better --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -55,7 +55,9 @@ extern "C" void run_mha(
|
||||
int is_causal,
|
||||
|
||||
int window_size_left,
|
||||
int window_size_right
|
||||
int window_size_right,
|
||||
|
||||
float softcap
|
||||
) {
|
||||
Flash_fwd_params params;
|
||||
// Reset the parameters
|
||||
@ -99,8 +101,16 @@ extern "C" void run_mha(
|
||||
params.d_rounded = d_rounded;
|
||||
|
||||
// Set the different scale values.
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
|
||||
if (softcap > 0.0) {
|
||||
params.softcap = softmax_scale / softcap;
|
||||
params.scale_softmax = softcap;
|
||||
params.scale_softmax_log2 = softcap * M_LOG2E;
|
||||
} else{
|
||||
// Remove potential NaN
|
||||
params.softcap = 0.0;
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
|
||||
}
|
||||
|
||||
params.p_dropout = 1.; // probability to keep
|
||||
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
|
||||
|
Reference in New Issue
Block a user