Lining up the flash attn version with the non-flash one. (#248)

* Move the flash-attn function into the proper crate.

* Causality tweak.
Laurent Mazare
2023-07-26 15:11:45 +01:00
committed by GitHub
parent 46f2d9f0ac
commit f052ba76cb
2 changed files with 28 additions and 12 deletions


@@ -146,19 +146,18 @@ struct CausalSelfAttention {
 }
 #[cfg(feature = "flash-attn")]
-fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    q.custom_op3(
-        k,
-        v,
-        candle_flash_attn::FlashHdim32Sm80 {
-            softmax_scale,
-            causal: true,
-        },
-    )
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
 }
 #[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
     unimplemented!("compile with '--features flash-attn'")
 }
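For context, a minimal sketch (not the crate's code) of the non-flash scaled-dot-product path that the wrapper above now mirrors argument-for-argument (q, k, v, softmax_scale, causal). The names `attention_reference` and `causal_mask`, the tensor shapes, and the use of `candle_nn::ops::softmax` are assumptions for illustration only.

// Sketch only: plain scaled-dot-product attention with the same argument order
// as the flash_attn wrapper above. Names and shapes are assumptions.
use candle_core::{Device, Result, Tensor, D};

// Lower-triangular additive mask: 0 where attending is allowed, -inf elsewhere.
// Assumes a square attention matrix (prefill, q_len == k_len).
fn causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    let data: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j <= i { 0.0 } else { f32::NEG_INFINITY }))
        .collect();
    Tensor::from_vec(data, (seq_len, seq_len), device)
}

// q, k, v: (batch, heads, seq_len, head_dim); returns the same shape.
fn attention_reference(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    // (q @ k^T) scaled by softmax_scale.
    let att = q.matmul(&k.t()?)?.affine(softmax_scale as f64, 0.0)?;
    let att = if causal {
        let (_b, _h, q_len, _k_len) = att.dims4()?;
        att.broadcast_add(&causal_mask(q_len, att.device())?)?
    } else {
        att
    };
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    att.matmul(v)
}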
@@ -225,7 +224,7 @@ impl CausalSelfAttention {
             let k = k.transpose(1, 2)?;
             let v = v.transpose(1, 2)?;
             let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-            flash_attn(softmax_scale, &q, &k, &v)?.transpose(1, 2)?
+            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
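The call site now passes `seq_len > 1` as the causal flag: when a single new token is decoded against the KV cache, every cached key is already in the token's past, so a causal mask would hide nothing and can be skipped, while prefill (seq_len > 1) still needs the triangular mask. A small illustrative check below; `masked_pairs` is a hypothetical helper, not part of the crate, and this is only a plausible reading of the "causality tweak".

// Hypothetical illustration (not from the crate): which (query i, key j) pairs a
// causal mask hides when q_len new tokens attend to past_len cached keys plus
// themselves. Query i sits at absolute position past_len + i.
fn masked_pairs(q_len: usize, past_len: usize) -> Vec<(usize, usize)> {
    let k_len = past_len + q_len;
    (0..q_len)
        .flat_map(|i| (0..k_len).map(move |j| (i, j)))
        .filter(|&(i, j)| j > past_len + i) // key j lies in query i's future
        .collect()
}

fn main() {
    // Prefill: 4 new tokens, empty cache -> the usual triangular mask (6 hidden pairs).
    assert_eq!(masked_pairs(4, 0).len(), 6);
    // Decode: 1 new token against a 4-token cache -> nothing to hide,
    // which is consistent with passing `seq_len > 1` as the causal flag above.
    assert!(masked_pairs(1, 4).is_empty());
}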