diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index 45051af9..95521265 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -220,7 +220,7 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let config = Config::stablelm_3b_4e1t();
+    let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
     let (model, device) = {
         let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index e86f8877..87e72396 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -19,10 +19,11 @@ pub struct Config {
     pub(crate) max_position_embeddings: usize,
     pub(crate) norm_eps: f64,
     pub(crate) use_cache: bool,
+    pub(crate) use_flash_attn: bool,
 }
 
 impl Config {
-    pub fn stablelm_3b_4e1t() -> Self {
+    pub fn stablelm_3b_4e1t(use_flash_attn: bool) -> Self {
         Self {
             vocab_size: 50304,
             intermediate_size: 6912,
@@ -36,6 +37,7 @@ impl Config {
             max_position_embeddings: 4096,
             norm_eps: 1e-5,
             use_cache: true,
+            use_flash_attn,
         }
     }
 
@@ -134,6 +136,22 @@ impl Module for MLP {
     }
 }
 
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug)]
 struct Attention {
     q_proj: Linear,
@@ -149,6 +167,7 @@ struct Attention {
     kv_cache: Option<(Tensor, Tensor)>,
     use_cache: bool,
     rotary_ndims: usize,
+    use_flash_attn: bool,
 }
 
 impl Attention {
@@ -175,6 +194,7 @@ impl Attention {
             kv_cache: None,
             use_cache: cfg.use_cache,
             rotary_ndims: cfg.rotary_ndims(),
+            use_flash_attn: cfg.use_flash_attn,
         })
     }
 
@@ -238,7 +258,14 @@ impl Attention {
         let key_states = self.repeat_kv(key_states)?.contiguous()?;
        let value_states = self.repeat_kv(value_states)?.contiguous()?;
 
-        let attn_output = {
+        let attn_output = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = query_states.transpose(1, 2)?;
+            let k = key_states.transpose(1, 2)?;
+            let v = value_states.transpose(1, 2)?;
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
+        } else {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
             let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
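For reference, a minimal sketch of how the changed constructor is used after this patch. It only assumes `candle-transformers` with this diff applied; the hypothetical `build_config` helper is for illustration, and passing `true` only makes sense on a build with `--features flash-attn` and a CUDA device, since the stub `flash_attn` otherwise panics via `unimplemented!`.

```rust
use candle_transformers::models::stable_lm::Config;

fn build_config(use_flash_attn: bool) -> Config {
    // The flag is stored on `Config` and copied into every `Attention` layer,
    // which then dispatches between the flash-attn kernel and the plain
    // matmul/softmax path in `Attention::forward`.
    Config::stablelm_3b_4e1t(use_flash_attn)
}
```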