Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Lining up the flash attn version with the non-flash one. (#248)
* Move the flash-attn function in the proper crate.
* Causality tweak.
@@ -146,19 +146,18 @@ struct CausalSelfAttention {
 }
 
 #[cfg(feature = "flash-attn")]
-fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    q.custom_op3(
-        k,
-        v,
-        candle_flash_attn::FlashHdim32Sm80 {
-            softmax_scale,
-            causal: true,
-        },
-    )
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
 }
 
 #[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
     unimplemented!("compile with '--features flash-attn'")
 }
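
Note on the cfg-gated pair above: both definitions of flash_attn now share the exact same signature, so the call site in the attention forward pass compiles unchanged whether or not the crate is built with the flash-attn feature. A minimal, self-contained sketch of that pattern follows; the feature name, function name, and error type are illustrative placeholders, not candle code.

// Sketch of feature-gated function pairs with matching signatures.
// Assumes a hypothetical Cargo feature named "fast-path":
//   [features]
//   fast-path = []
type Result<T> = std::result::Result<T, String>;

// Compiled only when the optional feature is enabled.
#[cfg(feature = "fast-path")]
fn scaled_sum(xs: &[f32], scale: f32) -> Result<f32> {
    Ok(xs.iter().sum::<f32>() * scale)
}

// Fallback with the same signature, so callers need no cfg of their own.
// (The candle stub calls unimplemented!(); this sketch just computes the value.)
#[cfg(not(feature = "fast-path"))]
fn scaled_sum(xs: &[f32], scale: f32) -> Result<f32> {
    Ok(xs.iter().map(|x| x * scale).sum())
}

fn main() -> Result<()> {
    // Single call site, identical under either build configuration.
    let y = scaled_sum(&[1.0, 2.0, 3.0], 0.5)?;
    println!("{y}");
    Ok(())
}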
@@ -225,7 +224,7 @@ impl CausalSelfAttention {
             let k = k.transpose(1, 2)?;
             let v = v.transpose(1, 2)?;
             let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-            flash_attn(softmax_scale, &q, &k, &v)?.transpose(1, 2)?
+            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
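
The causality tweak: the flash path now passes seq_len > 1 as the causal flag instead of hard-coding causal: true. The mask is only needed while several prompt positions are processed in one step; during single-token generation the lone query is the last position and may legitimately attend to every cached key. A stand-alone sketch of the mask this flag corresponds to (plain Rust with a hypothetical helper, not candle code):

// Entry (i, j) of the score matrix is masked when key j lies in the
// future of query i. With kv_len cached keys and seq_len fresh queries,
// the queries occupy the last seq_len positions of the sequence.
fn causal_mask(seq_len: usize, kv_len: usize) -> Vec<Vec<f32>> {
    let offset = kv_len - seq_len; // positions already in the KV cache
    (0..seq_len)
        .map(|i| {
            (0..kv_len)
                .map(|j| if j > i + offset { f32::NEG_INFINITY } else { 0.0 })
                .collect()
        })
        .collect()
}

fn main() {
    // Prompt processing (seq_len > 1): the upper triangle is masked.
    println!("{:?}", causal_mask(4, 4));
    // Single-token decoding (seq_len == 1): nothing is masked, so the
    // causal flag can safely be false -- hence `seq_len > 1` in the diff.
    println!("{:?}", causal_mask(1, 5));
}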