Lining up the flash attn version with the non-flash one. (#248)

* Move the flash-attn function in the proper crate.

* Causality tweak.
Laurent Mazare
2023-07-26 15:11:45 +01:00
committed by GitHub
parent 46f2d9f0ac
commit f052ba76cb
2 changed files with 28 additions and 12 deletions

@@ -3,7 +3,7 @@ mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape};
+use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
use half::f16;
pub struct FlashHdim32Sm80 {
@@ -144,3 +144,20 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
        Ok((dst, out_shape))
    }
}
pub fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    q.custom_op3(
        k,
        v,
        FlashHdim32Sm80 {
            softmax_scale,
            causal,
        },
    )
}
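
For illustration, here is a minimal sketch of how the new flash_attn entry point could be invoked. The crate import path, device and dtype setup, tensor shapes, and the 1/sqrt(head_dim) softmax scale are assumptions made for this example rather than part of the commit; the wrapped FlashHdim32Sm80 kernel targets f16 CUDA tensors with a head dimension of 32.

use candle::{DType, Device, Result, Tensor};
// Assumed import path for the flash-attn crate in this repo.
use candle_flash_attn::flash_attn;

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // Illustrative shapes: (batch, num_heads, seq_len, head_dim) with head_dim = 32.
    let q = Tensor::randn(0f32, 1.0, (1, 8, 128, 32), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1.0, (1, 8, 128, 32), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1.0, (1, 8, 128, 32), &device)?.to_dtype(DType::F16)?;

    // Standard attention scaling: 1 / sqrt(head_dim).
    let softmax_scale = 1f32 / (32f32).sqrt();

    // causal = true applies the lower-triangular mask inside the kernel,
    // lining it up with the masked, non-flash attention path.
    let out = flash_attn(&q, &k, &v, softmax_scale, true)?;
    println!("{:?}", out.shape());
    Ok(())
}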