From f9ecc8447753d759e776e762ba9309bb90b76bb3 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 22 Aug 2023 19:41:10 +0100
Subject: [PATCH] GQA support in the quantized model. (#555)

* GQA support in the quantized model.

* Fix the reshaping.

* Fix the main llama model.

* Infer the proper gqa from the model kind.
---
 candle-examples/examples/llama/model.rs    |  2 +-
 candle-examples/examples/quantized/main.rs | 36 +++++++++++++++++++---
 2 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 86d13bdb..561c2939 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -291,7 +291,7 @@ impl CausalSelfAttention {
             let x = x
                 .unsqueeze(2)?
                 .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
+                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
             Ok(x)
         }
     }
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 8411142e..477c695f 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -68,6 +68,7 @@ struct LayerWeights {
     feed_forward_w3: QMatMul,
     ffn_norm: RmsNorm,
     n_head: usize,
+    n_kv_head: usize,
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
@@ -125,10 +126,10 @@ impl LayerWeights {
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
             .transpose(1, 2)?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
             .transpose(1, 2)?;
         let v = v
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
             .transpose(1, 2)?;

         let q = self.apply_rotary_emb(&q, index_pos)?;
@@ -144,7 +145,9 @@ impl LayerWeights {
         };
         self.kv_cache = Some((k.clone(), v.clone()));

-        // If we start supporting MQA, we need to repeat the k and v tensors here.
+        // Support for MQA, useful for 70B models.
+        let k = self.repeat_kv(k)?;
+        let v = self.repeat_kv(v)?;

         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = mask.broadcast_as(att.shape())?;
@@ -156,6 +159,20 @@ impl LayerWeights {
         let y = self.attention_wo.forward(&y)?;
         Ok(y)
     }
+
+    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+        let n_rep = self.n_head / self.n_kv_head;
+        if n_rep == 1 {
+            Ok(x)
+        } else {
+            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
+            let x = x
+                .unsqueeze(2)?
+                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
+                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
+            Ok(x)
+        }
+    }
 }

 struct ModelWeights {
@@ -179,7 +196,7 @@ impl WeightMap {
 }

 impl ModelWeights {
-    fn new(mut ct: Content) -> Result<Self> {
+    fn new(mut ct: Content, gqa: usize) -> Result<Self> {
         let cpu = &Device::Cpu;
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;

@@ -226,6 +243,7 @@ impl ModelWeights {
                 feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
                 ffn_norm: RmsNorm::new(ffn_norm)?,
                 n_head: ct.hparams.n_head as usize,
+                n_kv_head: ct.hparams.n_head as usize / gqa,
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                 cos: cos.clone(),
                 sin: sin.clone(),
@@ -347,6 +365,10 @@ struct Args {
     /// The model size to use.
     #[arg(long, default_value = "7b")]
     which: Which,
+
+    /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
+    #[arg(long)]
+    gqa: Option<usize>,
 }

 impl Args {
@@ -468,7 +490,11 @@ fn main() -> anyhow::Result<()> {
         start.elapsed().as_secs_f32(),
     );
     println!("params: {:?}", model.hparams);
-    let mut model = ModelWeights::new(model)?;
+    let default_gqa = match args.which {
+        Which::L7b | Which::L13b => 1,
+        Which::L70b => 8,
+    };
+    let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?;
     println!("model built");

     let tokenizer = args.tokenizer()?;
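
For reference, below is a minimal, dependency-free Rust sketch of the key/value head repetition that the patch implements on Tensors with unsqueeze(2) + expand + reshape. The flat Vec<f32> layout, the helper name repeat_kv_flat, and the toy dimensions are assumptions made for illustration, not part of the patch or of candle's API; the point is only that each of the n_kv_head KV heads is duplicated n_rep = n_head / n_kv_head times along the head dimension so the grouped query heads can attend over shared keys and values.

// Sketch: repeat each of the n_kv_head key/value heads n_rep times along the
// head dimension, so n_head query heads can attend over shared KV heads.
// Input layout: [n_kv_head, seq_len, head_dim], flattened row-major.
fn repeat_kv_flat(x: &[f32], n_kv_head: usize, seq_len: usize, head_dim: usize, n_rep: usize) -> Vec<f32> {
    if n_rep == 1 {
        return x.to_vec();
    }
    let per_head = seq_len * head_dim;
    let mut out = Vec::with_capacity(x.len() * n_rep);
    for kv_head in 0..n_kv_head {
        let head = &x[kv_head * per_head..(kv_head + 1) * per_head];
        // Adjacent copies of the same KV head, matching the layout produced by
        // the expand + reshape in the patch.
        for _ in 0..n_rep {
            out.extend_from_slice(head);
        }
    }
    out
}

fn main() {
    // 70B-style grouping in miniature: 8 query heads sharing 2 KV heads (gqa = 4).
    let (n_head, n_kv_head, seq_len, head_dim) = (8, 2, 3, 4);
    let n_rep = n_head / n_kv_head;
    let kv: Vec<f32> = (0..n_kv_head * seq_len * head_dim).map(|i| i as f32).collect();
    let repeated = repeat_kv_flat(&kv, n_kv_head, seq_len, head_dim, n_rep);
    assert_eq!(repeated.len(), n_head * seq_len * head_dim);
    // Query head 5 reads from original KV head 5 / n_rep.
    assert_eq!(repeated[5 * seq_len * head_dim], kv[(5 / n_rep) * seq_len * head_dim]);
    println!("repeated {n_kv_head} KV heads into {n_head}");
}

The candle code in the diff performs the same duplication directly on the Tensor: unsqueeze(2) inserts a repeat axis, expand broadcasts it to n_rep, and reshape merges it into the head dimension, yielding a (b_sz, n_kv_head * n_rep, seq_len, head_dim) tensor that matches the query head count.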