diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index b8c74a2f..e0cecf15 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -113,6 +113,9 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
+    #[arg(long)]
+    use_flash_attn: bool,
+
     #[arg(long)]
     prompt: String,
 
@@ -207,7 +210,7 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let config = Config::config_7b_v0_1();
+    let config = Config::config_7b_v0_1(args.use_flash_attn);
     let device = candle_examples::device(args.cpu)?;
     let dtype = if device.is_cuda() {
         DType::BF16
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 346fda89..a7b4c21b 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -17,10 +17,11 @@ pub struct Config {
     rms_norm_eps: f64,
     rope_theta: f64,
     sliding_window: usize,
+    use_flash_attn: bool,
 }
 
 impl Config {
-    pub fn config_7b_v0_1() -> Self {
+    pub fn config_7b_v0_1(use_flash_attn: bool) -> Self {
         Self {
             vocab_size: 32000,
             hidden_size: 4096,
@@ -33,6 +34,7 @@ impl Config {
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
             sliding_window: 4096,
+            use_flash_attn,
         }
     }
 }
@@ -142,6 +144,22 @@ impl Module for MLP {
     }
 }
 
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug)]
 struct Attention {
     q_proj: Linear,
@@ -155,6 +173,7 @@ struct Attention {
     hidden_size: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     kv_cache: Option<(Tensor, Tensor)>,
+    use_flash_attn: bool,
 }
 
 impl Attention {
@@ -180,6 +199,7 @@ impl Attention {
             hidden_size: hidden_sz,
             rotary_emb,
             kv_cache: None,
+            use_flash_attn: cfg.use_flash_attn,
        })
     }
 
@@ -234,15 +254,24 @@ impl Attention {
         let key_states = self.repeat_kv(key_states)?;
         let value_states = self.repeat_kv(value_states)?;
 
-        let scale = 1f64 / f64::sqrt(self.head_dim as f64);
-        let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+        let attn_output = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = query_states.transpose(1, 2)?;
+            let k = key_states.transpose(1, 2)?;
+            let v = value_states.transpose(1, 2)?;
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
+        } else {
+            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
 
-        let attn_weights = match attention_mask {
-            None => attn_weights,
-            Some(mask) => attn_weights.broadcast_add(mask)?,
+            let attn_weights = match attention_mask {
+                None => attn_weights,
+                Some(mask) => attn_weights.broadcast_add(mask)?,
+            };
+            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+            attn_weights.matmul(&value_states)?
         };
-        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
-        let attn_output = attn_weights.matmul(&value_states)?;
         attn_output
             .transpose(1, 2)?
             .reshape((b_sz, q_len, self.hidden_size))?
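
For reference, the non-flash fallback that the patch keeps in the `else` branch can be exercised on its own. Below is a minimal, self-contained sketch of that scaled-dot-product path with toy shapes on CPU; the helper name `naive_attention` and the shapes are illustrative, not part of the patch. It also notes the (b_sz, seq_len, n_heads, head_dim) layout expected by `candle_flash_attn::flash_attn`, which is why the call site above transposes dims 1 and 2 before and after the call.

// Standalone sketch of the non-flash fallback path from the patch above,
// run on CPU with toy shapes. Names and shapes are illustrative only.
use candle_core::{Device, Result, Tensor};

fn naive_attention(q: &Tensor, k: &Tensor, v: &Tensor, head_dim: usize) -> Result<Tensor> {
    // q, k, v: (b_sz, n_heads, seq_len, head_dim)
    let scale = 1f64 / f64::sqrt(head_dim as f64);
    let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
    // No attention mask here; the real code broadcast-adds a causal mask when present.
    let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
    attn_weights.matmul(v)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b_sz, n_heads, seq_len, head_dim) = (1, 4, 8, 16);
    let q = Tensor::randn(0f32, 1f32, (b_sz, n_heads, seq_len, head_dim), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (b_sz, n_heads, seq_len, head_dim), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (b_sz, n_heads, seq_len, head_dim), &dev)?;
    let out = naive_attention(&q, &k, &v, head_dim)?;
    // flash_attn computes the same attention but takes (b_sz, seq_len, n_heads, head_dim),
    // hence the transpose(1, 2) calls around the call site in the patch.
    println!("{:?}", out.dims()); // [1, 4, 8, 16]
    Ok(())
}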