Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 11:56:45 +00:00
Lining up the flash attn version with the non-flash one. (#248)
* Move the flash-attn function in the proper crate.
* Causality tweak.
@@ -3,7 +3,7 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape};
+use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
 use half::f16;
 
 pub struct FlashHdim32Sm80 {
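Judging from the initializer in the new `flash_attn` function below, the custom op struct carries exactly the two parameters being threaded through the wrapper. A sketch of its likely shape (field visibility is an assumption, not shown in this diff):

    pub struct FlashHdim32Sm80 {
        pub softmax_scale: f32, // scaling applied before the softmax
        pub causal: bool,       // mask out attention to future positions
    }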
@@ -144,3 +144,20 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         Ok((dst, out_shape))
     }
 }
+
+pub fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        FlashHdim32Sm80 {
+            softmax_scale,
+            causal,
+        },
+    )
+}
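For orientation, a minimal sketch of how the newly exposed wrapper might be called. The crate path (`candle_flash_attn`), the CUDA device setup, the (batch, seq_len, num_heads, head_dim) layout, and the f16 conversion are illustrative assumptions, not part of this diff:

    use candle::{DType, Device, Result, Tensor};
    use candle_flash_attn::flash_attn; // assumed crate path for this file

    fn main() -> Result<()> {
        // Flash attention runs on a CUDA device; f16 inputs assumed.
        let device = Device::new_cuda(0)?;
        // Assumed layout (batch, seq_len, num_heads, head_dim), with
        // head_dim = 32 to match the FlashHdim32Sm80 op.
        let q = Tensor::randn(0f32, 1., (1, 128, 8, 32), &device)?.to_dtype(DType::F16)?;
        let k = q.clone();
        let v = q.clone();
        let softmax_scale = 1f32 / (32f32).sqrt(); // 1 / sqrt(head_dim)
        // causal = true masks attention to future positions.
        let out = flash_attn(&q, &k, &v, softmax_scale, true)?;
        println!("{:?}", out.shape());
        Ok(())
    }

Routing the call through `q.custom_op3(k, v, ...)` registers the CUDA kernel as a candle custom op over three input tensors, which is what lets this public wrapper live in the flash-attn crate alongside the op definition rather than in the main crate.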