Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Proper flash-attn parameters. (#244)
* Proper flash-attn parameters.
* Set the flash attention parameters.
* Add more validations.
* Setup the o_ flash attn parameters.
* More flash-attn support.
* Set more flash attn parameters.
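The parameter being wired through here is the softmax scale of scaled dot-product attention: the q·kᵀ logits are multiplied by 1/sqrt(head_dim) before the softmax, and since flash-attn fuses that step into its kernel, the caller has to pass the scale in explicitly. A minimal sketch of the computation (plain Rust; the head_dim value is an illustrative assumption, not taken from this commit):

fn main() {
    // Scaling the attention logits by 1/sqrt(head_dim) keeps their variance
    // roughly constant as the head dimension grows.
    let head_dim: usize = 64; // illustrative, e.g. hidden_size / num_heads
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    println!("softmax_scale = {softmax_scale}"); // prints 0.125 for head_dim = 64
}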
@@ -146,12 +146,19 @@ struct CausalSelfAttention {
 }

 #[cfg(feature = "flash-attn")]
-fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
+fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        candle_flash_attn::FlashHdim32Sm80 {
+            softmax_scale,
+            causal: true,
+        },
+    )
 }

 #[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
     unimplemented!("compile with '--features flash-attn'")
 }

@@ -213,7 +220,8 @@ impl CausalSelfAttention {
         let v = self.repeat_kv(v)?;

         let y = if self.use_flash_attn {
-            flash_attn(&q, &k, &v)?
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(softmax_scale, &q, &k, &v)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
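For context, a self-contained sketch of the cfg-gated fallback pattern this diff extends: both variants share one signature, so call sites compile whether or not the "flash-attn" feature is enabled. Tensor and Result below are stand-ins, not candle's types:

// Stand-in types; the real code uses candle's Tensor and Result.
struct Tensor;
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

#[cfg(feature = "flash-attn")]
fn flash_attn(softmax_scale: f32, _q: &Tensor, _k: &Tensor, _v: &Tensor) -> Result<Tensor> {
    // Real build: the scale (and causal flag) ride along into the fused kernel.
    let _ = softmax_scale;
    Ok(Tensor)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
    // Fallback build: fail loudly at runtime, mirroring the diff's stub.
    unimplemented!("compile with '--features flash-attn'")
}

fn main() -> Result<()> {
    let (q, k, v) = (Tensor, Tensor, Tensor);
    let softmax_scale = 1f32 / (64f32).sqrt(); // head_dim = 64, illustrative
    // Without the "flash-attn" feature this call panics by design.
    let _y = flash_attn(softmax_scale, &q, &k, &v)?;
    Ok(())
}

Keeping the stub's signature in lock-step with the real function is what lets the new softmax_scale argument be threaded through each call site exactly once, regardless of build configuration.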