From a9101700b6f87a66a538a87bb97fa7585a165165 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 16 Aug 2023 14:28:42 +0100
Subject: [PATCH] Add a kv-cache to the quantized llama example. (#466)

* Add a kv-cache to the quantized llama example.

* Also print the prompt.

* Bugfix in q6k dequantizing.

* Another bugfix.
---
 candle-core/src/quantized/k_quants.rs |  8 ++++----
 candle-examples/examples/ggml/main.rs | 21 ++++++++++++++++-----
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index f7611897..366eca1e 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -70,7 +70,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
 #[repr(C)]
 pub struct BlockQ8_0 {
     d: f16,
-    qs: [u8; QK8_0],
+    qs: [i8; QK8_0],
 }
 const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
 
@@ -476,14 +476,14 @@ impl GgmlType for BlockQ6K {
         if k % QK_K != 0 {
             crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
         }
-        for x in xs.iter() {
+        for (idx_x, x) in xs.iter().enumerate() {
             let d = x.d.to_f32();
             let ql = &x.ql;
             let qh = &x.qh;
             let sc = &x.scales;
             for n in (0..QK_K).step_by(128) {
                 let idx = n / 128;
-                let ys = &mut ys[n..];
+                let ys = &mut ys[idx_x * QK_K + n..];
                 let sc = &sc[8 * idx..];
                 let ql = &ql[64 * idx..];
                 let qh = &qh[32 * idx..];
@@ -663,7 +663,7 @@ impl GgmlType for BlockQ8_0 {
             let id = if d != 0f32 { 1. / d } else { 0. };
             ys.d = f16::from_f32(d);
             for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
-                *y = f32::round(x * id) as u8
+                *y = f32::round(x * id) as i8
             }
         }
         Ok(())
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index 912bc53a..20fa94cc 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -52,6 +52,7 @@ struct LayerWeights {
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
+    kv_cache: Option<(Tensor, Tensor)>,
 }
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@@ -75,7 +76,7 @@ impl LayerWeights {
         Ok(rope)
     }
 
-    fn forward_attn(&self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+    fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.attention_wq.forward(x)?;
         let k = self.attention_wk.forward(x)?;
@@ -94,7 +95,15 @@
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
         let k = self.apply_rotary_emb(&k, index_pos)?;
-        // TODO: KV cache.
+        let (k, v) = match &self.kv_cache {
+            None => (k, v),
+            Some((k_cache, v_cache)) => {
+                let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
+                let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+                (k, v)
+            }
+        };
+        self.kv_cache = Some((k.clone(), v.clone()));
 
         // If we start supporting MQA, we need to repeat the k and v tensors here.
@@ -181,6 +190,7 @@ impl ModelWeights {
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                 cos: cos.clone(),
                 sin: sin.clone(),
+                kv_cache: None,
             })
         }
         Ok(Self {
@@ -209,7 +219,7 @@ impl ModelWeights {
         let (_b_sz, seq_len) = x.dims2()?;
         let mask = self.mask(seq_len)?;
         let mut layer_in = self.tok_embeddings.forward(x)?;
-        for (_layer_idx, layer) in self.layers.iter().enumerate() {
+        for layer in self.layers.iter_mut() {
             let x = layer_in;
             let residual = &x;
             let x = layer.attention_norm.forward(&x)?;
@@ -302,8 +312,9 @@ fn main() -> anyhow::Result<()> {
         .to_vec();
     let mut index_pos = 0;
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
-    for _index in 0..args.sample_len {
-        let context_size = tokens.len();
+    print!("{prompt}");
+    for index in 0..args.sample_len {
+        let context_size = if index == 0 { tokens.len() } else { 1 };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
         let logits = model.forward(&input, index_pos)?;