From f052ba76cbf88f8e4f9fe38e76f7a2673da6b5f2 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 26 Jul 2023 15:11:45 +0100
Subject: [PATCH] Lining up the flash attn version with the non-flash one. (#248)

* Move the flash-attn function in the proper crate.

* Causality tweak.
---
 candle-examples/examples/llama/model.rs | 21 ++++++++++-----------
 candle-flash-attn/src/lib.rs            | 19 ++++++++++++++++++-
 2 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 049d0c38..d519cafe 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -146,19 +146,18 @@ struct CausalSelfAttention {
 }
 
 #[cfg(feature = "flash-attn")]
-fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    q.custom_op3(
-        k,
-        v,
-        candle_flash_attn::FlashHdim32Sm80 {
-            softmax_scale,
-            causal: true,
-        },
-    )
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
 }
 
 #[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
     unimplemented!("compile with '--features flash-attn'")
 }
 
@@ -225,7 +224,7 @@ impl CausalSelfAttention {
             let k = k.transpose(1, 2)?;
             let v = v.transpose(1, 2)?;
             let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-            flash_attn(softmax_scale, &q, &k, &v)?.transpose(1, 2)?
+            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index b159aee2..c2dec7d7 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -3,7 +3,7 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape};
+use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
 use half::f16;
 
 pub struct FlashHdim32Sm80 {
@@ -144,3 +144,20 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         Ok((dst, out_shape))
     }
 }
+
+pub fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        FlashHdim32Sm80 {
+            softmax_scale,
+            causal,
+        },
+    )
+}
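
Usage note (not part of the patch): a minimal sketch of how a caller can use the relocated candle_flash_attn::flash_attn helper, mirroring the transpose-call-transpose pattern from the model.rs hunk above. The wrapper name and the (batch, num_heads, seq_len, head_dim) input layout are assumptions for illustration; the kernel still requires f16 CUDA tensors and a build with the flash-attn feature.

use candle::{Result, Tensor};

// Hypothetical helper, not part of this patch: q, k and v are assumed to be
// f16 CUDA tensors shaped (batch, num_heads, seq_len, head_dim), as in the
// llama example above.
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Result<Tensor> {
    // The llama example passes tensors with the sequence dimension ahead of the
    // heads dimension, so swap dims 1 and 2 before the call and swap them back after.
    let q = q.transpose(1, 2)?;
    let k = k.transpose(1, 2)?;
    let v = v.transpose(1, 2)?;
    // Standard scaled dot-product attention scale: 1 / sqrt(head_dim).
    let softmax_scale = 1f32 / (q.dim(3)? as f32).sqrt();
    candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, causal)?.transpose(1, 2)
}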