Use flash-attn in gemma. (#2195)

* Use flash-attn in gemma.

* Fix flash-attn for head dim 256.
This commit is contained in:
Laurent Mazare
2024-05-18 19:18:59 +02:00
committed by GitHub
parent eefc1c77ef
commit 7ebc3548e1
4 changed files with 55 additions and 20 deletions

View File

@ -139,7 +139,9 @@ impl FlashAttn {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
.w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };