Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Lining up the flash attn version with the non-flash one. (#248)
* Move the flash-attn function into the proper crate.
* Causality tweak.
@@ -146,19 +146,18 @@ struct CausalSelfAttention {
 }
 
 #[cfg(feature = "flash-attn")]
-fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    q.custom_op3(
-        k,
-        v,
-        candle_flash_attn::FlashHdim32Sm80 {
-            softmax_scale,
-            causal: true,
-        },
-    )
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
 }
 
 #[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
     unimplemented!("compile with '--features flash-attn'")
 }
 
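Editor's note: the example-side wrapper now takes q, k, v first, then softmax_scale and an explicit causal flag, mirroring the crate-level function it forwards to, and the `#[cfg(not(...))]` stub keeps the identical signature so call sites compile under either feature set and only fail at runtime when the feature is missing. A minimal sketch of such a caller follows; the `attend` helper and its arguments are illustrative names of my own, not code from the commit.

use candle::{Result, Tensor};

// Hypothetical call site: `flash_attn` resolves to whichever cfg branch above
// was compiled in.
fn attend(q: &Tensor, k: &Tensor, v: &Tensor, head_dim: usize, seq_len: usize) -> Result<Tensor> {
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    // Causal masking is only relevant when more than one query position is present.
    flash_attn(q, k, v, softmax_scale, seq_len > 1)
}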
@@ -225,7 +224,7 @@ impl CausalSelfAttention {
             let k = k.transpose(1, 2)?;
             let v = v.transpose(1, 2)?;
             let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-            flash_attn(softmax_scale, &q, &k, &v)?.transpose(1, 2)?
+            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
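Editor's note: this is the causality tweak. Instead of a hard-coded `causal: true`, the flag is now `seq_len > 1`: during single-token decoding against the cached keys and values the lone query row may attend to every cached position, so a causal mask is a no-op and can be skipped, while the multi-token prompt pass still masks future tokens. A rough sketch of the keep-mask that the flag implies (my own illustration, not code from the repository):

// Lower-triangular keep-mask: entry (i, j) is 1 when query i may attend to key j.
// For seq_len == 1 this is just [[1]], i.e. no masking at all.
fn causal_mask(seq_len: usize) -> Vec<Vec<u8>> {
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| u8::from(j <= i)).collect())
        .collect()
}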
@@ -3,7 +3,7 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape};
+use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
 use half::f16;
 
 pub struct FlashHdim32Sm80 {
@@ -144,3 +144,20 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         Ok((dst, out_shape))
     }
 }
+
+pub fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        FlashHdim32Sm80 {
+            softmax_scale,
+            causal,
+        },
+    )
+}
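Editor's note: with the wrapper exported from candle-flash-attn itself, downstream code can call it without touching CustomOp3 directly. A hedged usage sketch follows; it assumes a CUDA device, f16 inputs, and the (batch, seq_len, num_heads, head_dim) layout the llama example passes in after its transposes, none of which are spelled out by this diff alone.

// Inside the candle workspace the core crate is imported as `candle`
// (published as candle-core); outside it the path would differ.
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let (b, seq, heads, hd) = (1, 8, 32, 64);
    // Random f16 projections standing in for real query/key/value tensors.
    let q = Tensor::randn(0f32, 1., (b, seq, heads, hd), &dev)?.to_dtype(DType::F16)?;
    let k = q.clone();
    let v = q.clone();
    let softmax_scale = 1f32 / (hd as f32).sqrt();
    // Same causal convention as the example: mask only for multi-token passes.
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq > 1)?;
    println!("{:?}", out.shape());
    Ok(())
}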