diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 049d0c38..d519cafe 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -146,19 +146,18 @@ struct CausalSelfAttention {
 }
 
 #[cfg(feature = "flash-attn")]
-fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    q.custom_op3(
-        k,
-        v,
-        candle_flash_attn::FlashHdim32Sm80 {
-            softmax_scale,
-            causal: true,
-        },
-    )
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
 }
 
 #[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
     unimplemented!("compile with '--features flash-attn'")
 }
 
@@ -225,7 +224,7 @@ impl CausalSelfAttention {
             let k = k.transpose(1, 2)?;
             let v = v.transpose(1, 2)?;
             let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-            flash_attn(softmax_scale, &q, &k, &v)?.transpose(1, 2)?
+            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index b159aee2..c2dec7d7 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -3,7 +3,7 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape};
+use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
 use half::f16;
 
 pub struct FlashHdim32Sm80 {
@@ -144,3 +144,20 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         Ok((dst, out_shape))
     }
 }
+
+pub fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        FlashHdim32Sm80 {
+            softmax_scale,
+            causal,
+        },
+    )
+}
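
For reference, a minimal caller-side sketch of the new public `candle_flash_attn::flash_attn` helper introduced above. This snippet is not part of the patch: the tensor shapes, dtype, and device setup are illustrative assumptions (flash attention expects f16 CUDA tensors laid out as `(batch, seq_len, num_heads, head_dim)`).

```rust
// Usage sketch (assumed example, not from the patch): call the new public
// flash_attn helper on randomly initialized f16 CUDA tensors.
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // Assumed shapes: batch=1, seq_len=8, num_heads=4, head_dim=32.
    let (b, s, h, d) = (1, 8, 4, 32);
    let q = Tensor::randn(0f32, 1., (b, s, h, d), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1., (b, s, h, d), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1., (b, s, h, d), &device)?.to_dtype(DType::F16)?;
    let softmax_scale = 1f32 / (d as f32).sqrt();
    // Causal masking enabled, mirroring the decoder use in the llama example.
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, true)?;
    println!("{:?}", out.shape());
    Ok(())
}
```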