mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 12:28:06 +00:00
Use flash-attn in gemma. (#2195)
* Use flash-attn in gemma.
* Fix flash-attn for head dim 256.
This commit is contained in:
@@ -139,7 +139,9 @@ impl FlashAttn {
|
||||
|
||||
let elem_count = out_shape.elem_count();
|
||||
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
||||
-let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
|
||||
+let softmax_lse = dev
|
||||
+    .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
|
||||
+    .w()?;
|
||||
|
||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||
|
||||
|
Reference in New Issue
Block a user