Proper flash-attn parameters. (#244)

* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Set up the o_ flash-attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
Laurent Mazare
2023-07-26 10:13:40 +01:00
committed by GitHub
parent e40b150bbe
commit fa2b64d678
5 changed files with 147 additions and 12 deletions
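The "Add more validations" bullet above refers to checks inside the candle-flash-attn custom op, which are not shown in the hunks below. As a rough illustration only, input validation for such a kernel wrapper could look like the sketch that follows; the function name, the f16-only restriction, and the exact checks are assumptions rather than code from this commit (crate paths assume candle-core aliased as candle, as in the examples crate).

use candle::{DType, Result, Tensor};

// Hypothetical sketch, not code from this commit: typical input checks before
// dispatching to a flash-attention kernel.
fn validate_flash_attn_inputs(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<()> {
    // Assumed restriction: the kernel only handles half-precision inputs.
    for (name, t) in [("q", q), ("k", k), ("v", v)] {
        if t.dtype() != DType::F16 {
            candle::bail!("flash-attn expects f16 tensors, got {:?} for {name}", t.dtype())
        }
    }
    // k and v must agree on shape; q may differ only in the sequence dimension.
    if k.dims() != v.dims() {
        candle::bail!("k/v shape mismatch: {:?} vs {:?}", k.dims(), v.dims())
    }
    Ok(())
}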

@@ -146,12 +146,19 @@ struct CausalSelfAttention {
 }
 
 #[cfg(feature = "flash-attn")]
-fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
+fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        candle_flash_attn::FlashHdim32Sm80 {
+            softmax_scale,
+            causal: true,
+        },
+    )
 }
 
 #[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
     unimplemented!("compile with '--features flash-attn'")
 }
 
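For reference, the two new fields map onto the standard attention formula: the kernel computes softmax(q @ k^T * softmax_scale) @ v, and causal: true makes it mask out future positions as well. The scale passed in the next hunk, 1 / sqrt(head_dim), is the usual scaled dot-product convention. Below is a minimal sketch of the same computation with plain candle ops; the helper name, the assumed (batch, heads, seq, head_dim) layout, and the omission of the causal mask are illustrative assumptions, not code from this commit.

use candle::{D, Result, Tensor};

// Hypothetical reference helper, not code from this commit: scaled dot-product
// attention over tensors assumed to be shaped (batch, heads, seq, head_dim).
// The causal mask that the flash-attn kernel applies internally is omitted here.
fn naive_attention(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    // softmax(q @ k^T * softmax_scale) @ v
    let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    att.matmul(v)
}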
@@ -213,7 +220,8 @@ impl CausalSelfAttention {
         let v = self.repeat_kv(v)?;
 
         let y = if self.use_flash_attn {
-            flash_attn(&q, &k, &v)?
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(softmax_scale, &q, &k, &v)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
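The hunk is cut off after the f32 upcast. For context only: the non-flash branch of code like this typically finishes the attention in f32 (so the softmax stays numerically stable) and casts the result back to the input dtype, while the flash-attn path consumes the half-precision tensors directly. A hedged sketch of that shape of code, reusing the hypothetical naive_attention helper from the earlier sketch:

use candle::{DType, Result, Tensor};

// Hypothetical sketch, not the file's actual code: the non-flash fallback
// upcasts to f32, runs the reference attention, then casts back so both code
// paths return tensors of the original dtype.
fn fallback_attention(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let in_dtype = q.dtype();
    let q = q.to_dtype(DType::F32)?;
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;
    naive_attention(softmax_scale, &q, &k, &v)?.to_dtype(in_dtype)
}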