From 3ad4770eb61be34e6d2a7914a935b007d8dee49f Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 12 Apr 2024 09:15:10 +0200
Subject: [PATCH] Use cat for faster MQA computation. (#2043)

* Use cat for faster MQA computation.

* Move the function to utils + use it in mistral.

* Use the shared repeat-kv in a few more models.

* Fix.
---
 candle-transformers/src/models/gemma.rs | 16 ++-------------
 candle-transformers/src/models/llama.rs | 12 +----------
 candle-transformers/src/models/mistral.rs | 16 ++-------------
 candle-transformers/src/models/mixtral.rs | 16 ++-------------
 candle-transformers/src/models/mpt.rs | 18 ++---------------
 candle-transformers/src/models/phi.rs | 10 +---------
 .../src/models/quantized_llama.rs | 20 +++----------------
 .../src/models/quantized_mistral.rs | 16 ++-------------
 .../src/models/quantized_mpt.rs | 4 ++--
 .../src/models/quantized_stable_lm.rs | 17 +++-------------
 candle-transformers/src/models/qwen2.rs | 17 +++-------------
 candle-transformers/src/models/qwen2_moe.rs | 17 +++-------------
 candle-transformers/src/models/stable_lm.rs | 17 +++-------------
 candle-transformers/src/models/starcoder2.rs | 16 ++-------------
 candle-transformers/src/models/yi.rs | 16 ++-------------
 candle-transformers/src/utils.rs | 14 +++++++++++++
 16 files changed, 47 insertions(+), 195 deletions(-)

diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 15e4dccb..58b5f1e1 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -191,18 +191,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -239,8 +227,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 73671cdc..f3d482eb 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -256,17 +256,7 @@ impl CausalSelfAttention {
     }
 
     fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_attention_heads / self.num_key_value_heads;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
-            let x = x
-                .unsqueeze(2)?
-                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
-            Ok(x)
-        }
+        crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
     }
 
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index d899c712..1cb55f9e 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -216,18 +216,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -266,8 +254,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = if self.use_flash_attn {
             // flash-attn expects (b_sz, seq_len, nheads, head_dim)
diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs
index f69c68e3..a578d6fe 100644
--- a/candle-transformers/src/models/mixtral.rs
+++ b/candle-transformers/src/models/mixtral.rs
@@ -158,18 +158,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -206,8 +194,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = if self.use_flash_attn {
             // flash-attn expects (b_sz, seq_len, nheads, head_dim)
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index 093e177c..d46524fc 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -104,8 +104,8 @@ impl GroupedQueryAttention {
         };
         self.kv_cache = Some((key.clone(), value.clone()));
         let query = query.contiguous()?;
-        let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
-        let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
         let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
         let attn_bias = {
             let s_q = query.dim(D::Minus2)?;
@@ -134,20 +134,6 @@ impl GroupedQueryAttention {
     }
 }
 
-// This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
-// The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
-// (batch, num_attention_heads, seqlen, head_dim)
-pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
-    if n_rep == 1 {
-        Ok(xs)
-    } else {
-        let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-        xs.unsqueeze(2)?
-            .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-            .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-    }
-}
-
 #[derive(Debug, Clone)]
 struct Ffn {
     up_proj: Linear,
diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs
index 8bf357e7..2c7fccef 100644
--- a/candle-transformers/src/models/phi.rs
+++ b/candle-transformers/src/models/phi.rs
@@ -174,15 +174,7 @@ impl Attention {
     }
 
     fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_heads / self.num_kv_heads;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
+        crate::utils::repeat_kv(xs, self.num_heads / self.num_kv_heads)
     }
 
     fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index e1519b2d..6b326fbe 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -205,9 +205,9 @@ impl LayerWeights {
         };
         self.kv_cache = Some((k.clone(), v.clone()));
 
-        // Support for MQA, useful for 70B models.
-        let k = self.repeat_kv(k)?;
-        let v = self.repeat_kv(v)?;
+        // Support for MQA, useful for 70B models and mistral.
+        let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
+        let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let att = match mask {
@@ -224,20 +224,6 @@ impl LayerWeights {
         let y = self.attention_wo.forward(&y)?;
         Ok(y)
     }
-
-    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
-        let n_rep = self.n_head / self.n_kv_head;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
-            let x = x
-                .unsqueeze(2)?
-                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
-            Ok(x)
-        }
-    }
 }
 
 #[derive(Debug, Clone)]
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index e37785de..0583810a 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -122,18 +122,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -172,8 +160,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
index 70a9e125..056fcac2 100644
--- a/candle-transformers/src/models/quantized_mpt.rs
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -71,8 +71,8 @@ impl GroupedQueryAttention {
         };
         self.kv_cache = Some((key.clone(), value.clone()));
         let query = query.contiguous()?;
-        let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
-        let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
         let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
         let attn_bias = {
             let s_q = query.dim(D::Minus2)?;
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index 7d4385a7..da447522 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -94,18 +94,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -152,8 +140,9 @@ impl Attention {
             self.kv_cache = Some((key_states.clone(), value_states.clone()));
         }
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 9a12eba5..06f9069a 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -146,18 +146,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -194,8 +182,9 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/qwen2_moe.rs b/candle-transformers/src/models/qwen2_moe.rs
index d6566e90..5650e350 100644
--- a/candle-transformers/src/models/qwen2_moe.rs
+++ b/candle-transformers/src/models/qwen2_moe.rs
@@ -151,18 +151,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -199,8 +187,9 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index a1d58936..2b46e8a1 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -217,18 +217,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -275,8 +263,9 @@ impl Attention {
            self.kv_cache = Some((key_states.clone(), value_states.clone()));
         }
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
         let attn_output = if self.use_flash_attn {
             // flash-attn expects (b_sz, seq_len, nheads, head_dim)
diff --git a/candle-transformers/src/models/starcoder2.rs b/candle-transformers/src/models/starcoder2.rs
index da3f6799..d108d062 100644
--- a/candle-transformers/src/models/starcoder2.rs
+++ b/candle-transformers/src/models/starcoder2.rs
@@ -139,18 +139,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -187,8 +175,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let scale = 1f64 / f64::sqrt(self.head_dim as f64);
         let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs
index 99d9de1b..df78ddce 100644
--- a/candle-transformers/src/models/yi.rs
+++ b/candle-transformers/src/models/yi.rs
@@ -175,18 +175,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -223,8 +211,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs
index d29995ed..17e83694 100644
--- a/candle-transformers/src/utils.rs
+++ b/candle-transformers/src/utils.rs
@@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R
     let logits_len = logits.len();
     Tensor::from_vec(logits, logits_len, device)
 }
+
+/// Repeats a key or value tensor for grouped query attention.
+/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`.
+pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
+    if n_rep == 1 {
+        Ok(xs)
+    } else {
+        let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
+        // Using cat is faster than a broadcast as it avoids going through a potentially
+        // strided copy.
+        // https://github.com/huggingface/candle/pull/2043
+        Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
+    }
+}
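
Reviewer note (not part of the patch): the cat-based rewrite is equivalent to the removed unsqueeze/expand/reshape path. Concatenating n_rep copies along the seq_len axis and then reshaping places each copy right after its source head, which is the same repeat_interleave-along-the-head-dimension layout the old code produced. The sketch below is a rough standalone check of that equivalence; it assumes a scratch binary crate with candle-core as a dependency, and the helper names repeat_kv_expand and repeat_kv_cat are invented here purely for illustration.

// Standalone equivalence check for the two repeat_kv formulations (a sketch, not
// part of the patch).
use candle_core::{Device, Result, Tensor};

// Old formulation removed above: broadcast via unsqueeze/expand, then reshape,
// which materializes the broadcast copy.
fn repeat_kv_expand(xs: &Tensor, n_rep: usize) -> Result<Tensor> {
    let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
    xs.unsqueeze(2)?
        .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
        .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}

// New formulation added in candle-transformers/src/utils.rs: concatenate n_rep
// copies along the seq_len axis, then reshape so each copy lands next to its
// source head.
fn repeat_kv_cat(xs: &Tensor, n_rep: usize) -> Result<Tensor> {
    let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
    Tensor::cat(&vec![xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Tiny (batch=1, num_kv_heads=2, seq_len=3, head_dim=4) tensor with distinct values.
    let xs = Tensor::arange(0f32, 24f32, &dev)?.reshape((1, 2, 3, 4))?;
    let n_rep = 4; // e.g. 32 query heads sharing 8 kv heads
    let a = repeat_kv_expand(&xs, n_rep)?;
    let b = repeat_kv_cat(&xs, n_rep)?;
    assert_eq!(a.dims(), b.dims());
    assert_eq!(
        a.flatten_all()?.to_vec1::<f32>()?,
        b.flatten_all()?.to_vec1::<f32>()?
    );
    println!("both repeat_kv variants agree on shape {:?}", a.dims());
    Ok(())
}

The speedup claimed in the commit comes from Tensor::cat writing a contiguous result directly, whereas the expand path builds a stride-0 broadcast view that the following reshape then has to copy out, the "potentially strided copy" mentioned in the comment added to candle-transformers/src/utils.rs.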