Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 02:16:37 +00:00)
Flash-Attn upgrade / SoftCap Candle-FlashAttn [3/n] (#2690)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better
* unpadded lse added
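The soft-capping mentioned in the bullets squashes raw attention scores into (-softcap, softcap) via softcap * tanh(score / softcap) before the softmax, with 0 meaning "disabled", which matches the self.softcap.unwrap_or(0.0) defaults in the hunks below. A minimal sketch in plain Rust (apply_softcap is an illustrative name, not a function from the crate):

    /// Illustrative soft-capping of raw attention scores: each score is
    /// squashed into (-softcap, softcap) via softcap * tanh(score / softcap).
    /// A softcap of 0.0 is treated as "disabled".
    fn apply_softcap(scores: &mut [f32], softcap: f32) {
        if softcap == 0.0 {
            return; // no capping requested
        }
        for s in scores.iter_mut() {
            *s = softcap * (*s / softcap).tanh();
        }
    }

    fn main() {
        let mut scores = vec![-120.0, -3.0, 0.0, 3.0, 120.0];
        apply_softcap(&mut scores, 30.0);
        println!("{scores:?}"); // large magnitudes end up clamped near ±30
    }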
@@ -53,6 +53,7 @@ extern "C" void run_mha(

    int is_bf16,
    int is_causal,
+   int unpadded_lse,

    int window_size_left,
    int window_size_right,
@@ -128,6 +129,7 @@ extern "C" void run_mha(

    params.is_seqlens_k_cumulative = true;
    params.num_splits = 1;
+   params.unpadded_lse = unpadded_lse;

    cudaStream_t stream = 0; // Use the default stream.
    run_mha_fwd(params, stream);
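The new unpadded_lse flag tells the kernel how the log-sum-exp (LSE) output is laid out: padded per batch entry as (batch, num_heads, max_seqlen_q), or packed as (num_heads, total_q), where total_q is the sum of the per-sequence query lengths (this reading is suggested by the allocation change further down). A rough sketch of the two element counts; the function names are illustrative, not taken from the crate:

    /// Element count of a padded LSE buffer: every sequence is padded to max_seqlen_q.
    fn lse_len_padded(batch_size: usize, num_heads: usize, max_seqlen_q: usize) -> usize {
        batch_size * num_heads * max_seqlen_q
    }

    /// Element count of an unpadded ("packed") LSE buffer: only real tokens are stored.
    fn lse_len_unpadded(num_heads: usize, total_q: usize) -> usize {
        num_heads * total_q
    }

    fn main() {
        // e.g. 3 variable-length sequences of 5, 17 and 2 query tokens, 8 heads
        let seqlens_q = [5usize, 17, 2];
        let total_q: usize = seqlens_q.iter().sum();
        let max_seqlen_q = *seqlens_q.iter().max().unwrap();
        println!("padded:   {}", lse_len_padded(seqlens_q.len(), 8, max_seqlen_q)); // 3 * 8 * 17 = 408
        println!("unpadded: {}", lse_len_unpadded(8, total_q)); // 8 * 24 = 192
    }

The saving grows with the spread of sequence lengths, since padding no longer costs anything.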
@@ -42,6 +42,7 @@ extern "C" {

        is_bf16: c_int,
        is_causal: c_int,
+       unpadded_lse: c_int,

        window_size_left: c_int,
        window_size_right: c_int,
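This Rust declaration has to stay in lock-step with the C signature in the first hunk: FFI arguments are matched purely by position, so adding int unpadded_lse on the C side without the matching c_int here would silently shift every later argument. A toy illustration of the pattern, with toy_kernel and its parameters hypothetical rather than part of candle-flash-attn:

    use std::os::raw::c_int;

    // Hypothetical C side:  void toy_kernel(int is_causal, int unpadded_lse, int window_size_left);
    extern "C" {
        // The Rust declaration repeats the same types in the same order;
        // a missing or reordered field here would misalign every argument after it.
        fn toy_kernel(is_causal: c_int, unpadded_lse: c_int, window_size_left: c_int);
    }

    // Declaration-only sketch: the function is never called, so nothing needs to link against it.
    fn main() {}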
@@ -200,6 +200,7 @@ impl FlashAttn {
            /* seqlen_k_rounded */ seqlen_k_rounded as u32,
            /* is_bf16 */ is_bf16,
            /* is_causal */ is_causal,
+           /* unpadded_lse */ 0,
            /* window_size_left */ window_size_left,
            /* window_size_right */ window_size_right,
            /* softcap */ self.softcap.unwrap_or(0f32),
@@ -518,7 +519,7 @@ impl FlashAttnVarLen {
            candle::bail!("the last dim of v must be contiguous {v_stride:?}")
        }

-       let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
+       let (total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
        let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?;
        let expected_kv = (total_k, num_heads_k, head_size_og);
        if expected_kv != k_l.shape().dims3()? {
@@ -601,9 +602,7 @@ impl FlashAttnVarLen {

        let elem_count = out_shape.elem_count();
        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-       let softmax_lse = dev
-           .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
-           .w()?;
+       let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;

        let is_bf16 = if is_bf16 { 1 } else { 0 };
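The varlen path now allocates exactly num_heads * total_q floats for the LSE instead of batch_size * num_heads * self.max_seqlen_q, which is why total_q is bound in the previous hunk. Assuming a row-major (num_heads, total_q) layout addressed through the cumulative query lengths (an assumption about the kernel's layout, not something stated in this diff), locating one token's LSE looks roughly like this:

    /// Index of the LSE value for `head`, sequence `batch`, query position `pos`
    /// in a packed (num_heads, total_q) buffer. cu_seqlens_q holds the cumulative
    /// query lengths, so sequence `batch` occupies cu_seqlens_q[batch]..cu_seqlens_q[batch + 1].
    /// Row-major layout is assumed; the real kernel layout may differ.
    fn unpadded_lse_index(
        cu_seqlens_q: &[usize],
        total_q: usize,
        head: usize,
        batch: usize,
        pos: usize,
    ) -> usize {
        head * total_q + cu_seqlens_q[batch] + pos
    }

    fn main() {
        let cu_seqlens_q = [0usize, 5, 22, 24]; // sequences of length 5, 17 and 2
        let total_q = *cu_seqlens_q.last().unwrap();
        // head 1, second sequence, 4th query token
        println!("{}", unpadded_lse_index(&cu_seqlens_q, total_q, 1, 1, 3)); // 1 * 24 + 5 + 3 = 32
    }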
@@ -663,6 +662,7 @@ impl FlashAttnVarLen {
            /* seqlen_k_rounded */ seqlen_k_rounded as u32,
            /* is_bf16 */ is_bf16,
            /* is_causal */ is_causal,
+           /* unpadded_lse */ 1,
            /* window_size_left */ window_size_left,
            /* window_size_right */ window_size_right,
            /* softcap */ self.softcap.unwrap_or(0.0),
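Note the two call sites: the dense FlashAttn path passes 0 (padded LSE), the variable-length path passes 1 (packed LSE). The commit message also mentions a softcap test; a naive single-head reference in plain Rust shows where the capping sits, right after the scaled Q·K product and before the softmax. This is a sketch of the math, not the test shipped in the repo:

    /// Naive single-head attention over row-major [seq, dim] slices, with optional
    /// tanh soft-capping of the scaled scores. Purely a reference for eyeballing
    /// kernel output; not the actual test from the commit.
    fn attention_ref(q: &[f32], k: &[f32], v: &[f32], seq: usize, dim: usize, softcap: f32) -> Vec<f32> {
        let scale = 1.0 / (dim as f32).sqrt();
        let mut out = vec![0.0f32; seq * dim];
        for i in 0..seq {
            // scaled scores for query i, soft-capped if requested
            let mut scores: Vec<f32> = (0..seq)
                .map(|j| {
                    let dot: f32 = (0..dim).map(|d| q[i * dim + d] * k[j * dim + d]).sum();
                    let s = dot * scale;
                    if softcap > 0.0 { softcap * (s / softcap).tanh() } else { s }
                })
                .collect();
            // numerically stable softmax
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let mut denom = 0.0f32;
            for s in scores.iter_mut() {
                *s = (*s - max).exp();
                denom += *s;
            }
            // weighted sum of V rows
            for j in 0..seq {
                let w = scores[j] / denom;
                for d in 0..dim {
                    out[i * dim + d] += w * v[j * dim + d];
                }
            }
        }
        out
    }

    fn main() {
        let (seq, dim) = (2, 2);
        let q = [1.0, 0.0, 0.0, 1.0];
        let k = q;
        let v = [1.0, 2.0, 3.0, 4.0];
        println!("{:?}", attention_ref(&q, &k, &v, seq, dim, 30.0));
    }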