From 28057781aa7d77c602b1ec89838db90bc82ecf8a Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 22 Feb 2024 12:04:33 +0100
Subject: [PATCH] Make the cache for the llama model explicit too. (#1745)

---
 candle-examples/examples/llama/main.rs  |  6 +--
 candle-transformers/src/models/llama.rs | 70 ++++++++++++++-----------
 2 files changed, 41 insertions(+), 35 deletions(-)

diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index e95321c7..f7998396 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -120,7 +120,7 @@ fn main() -> Result<()> {
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
         None => DType::F16,
     };
-    let (llama, tokenizer_filename, cache) = {
+    let (llama, tokenizer_filename, mut cache) = {
         let api = Api::new()?;
         let model_id = args.model_id.unwrap_or_else(|| match args.which {
             Which::V1 => "Narsil/amall-7b".to_string(),
@@ -146,7 +146,7 @@
         let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
 
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-        (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
+        (Llama::load(vb, &config)?, tokenizer_filename, cache)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@@ -172,7 +172,7 @@
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = llama.forward(&input, context_index)?;
+        let logits = llama.forward(&input, context_index, &mut cache)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1. {
             logits
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index f8126394..a091d3eb 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -2,7 +2,6 @@ use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use std::collections::HashMap;
-use std::sync::{Arc, Mutex};
 
 pub const MAX_SEQ_LEN: usize = 4096;
 
@@ -84,10 +83,9 @@ impl Config {
 
 #[derive(Debug, Clone)]
 pub struct Cache {
-    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    masks: HashMap<usize, Tensor>,
     pub use_kv_cache: bool,
-    #[allow(clippy::type_complexity)]
-    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+    kvs: Vec<Option<(Tensor, Tensor)>>,
     cos: Tensor,
     sin: Tensor,
     device: Device,
@@ -112,25 +110,24 @@ impl Cache {
         let cos = idx_theta.cos()?.to_dtype(dtype)?;
         let sin = idx_theta.sin()?.to_dtype(dtype)?;
         Ok(Self {
-            masks: Arc::new(Mutex::new(HashMap::new())),
+            masks: HashMap::new(),
             use_kv_cache,
-            kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
+            kvs: vec![None; config.num_hidden_layers],
             device: device.clone(),
             cos,
             sin,
         })
     }
 
-    fn mask(&self, t: usize) -> Result<Tensor> {
-        let mut masks = self.masks.lock().unwrap();
-        if let Some(mask) = masks.get(&t) {
+    fn mask(&mut self, t: usize) -> Result<Tensor> {
+        if let Some(mask) = self.masks.get(&t) {
             Ok(mask.clone())
         } else {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
             let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
-            masks.insert(t, mask.clone());
+            self.masks.insert(t, mask.clone());
             Ok(mask)
         }
     }
@@ -164,7 +161,6 @@ struct CausalSelfAttention {
     num_attention_heads: usize,
    num_key_value_heads: usize,
     head_dim: usize,
-    cache: Cache,
     use_flash_attn: bool,
     span: tracing::Span,
     span_rot: tracing::Span,
@@ -187,11 +183,11 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
 }
 
 impl CausalSelfAttention {
-    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
         let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
-        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+        let cos = cache.cos.narrow(0, index_pos, seq_len)?;
+        let sin = cache.sin.narrow(0, index_pos, seq_len)?;
         let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
         let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
         let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
@@ -201,7 +197,13 @@ impl CausalSelfAttention {
         Ok(rope)
     }
 
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let (b_sz, seq_len, hidden_size) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
@@ -218,12 +220,11 @@ impl CausalSelfAttention {
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
             .transpose(1, 2)?;
 
-        let q = self.apply_rotary_emb(&q, index_pos)?;
-        let mut k = self.apply_rotary_emb(&k, index_pos)?;
+        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
+        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
 
-        if self.cache.use_kv_cache {
-            let mut cache = self.cache.kvs.lock().unwrap();
-            if let Some((cache_k, cache_v)) = &cache[block_idx] {
+        if cache.use_kv_cache {
+            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
                 k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
                 v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
@@ -239,7 +240,7 @@ impl CausalSelfAttention {
                         .contiguous()?
                 }
             }
-            cache[block_idx] = Some((k.clone(), v.clone()))
+            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
         }
 
         let k = self.repeat_kv(k)?;
@@ -258,7 +259,7 @@ impl CausalSelfAttention {
             let k = k.to_dtype(DType::F32)?;
             let v = v.to_dtype(DType::F32)?;
             let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-            let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
             let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
             let att = candle_nn::ops::softmax(&att, D::Minus1)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
@@ -283,7 +284,7 @@ impl CausalSelfAttention {
         }
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "attn");
         let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
         let size_in = cfg.hidden_size;
@@ -301,7 +302,6 @@ impl CausalSelfAttention {
             num_attention_heads: cfg.num_attention_heads,
             num_key_value_heads: cfg.num_key_value_heads,
             head_dim: cfg.hidden_size / cfg.num_attention_heads,
-            cache: cache.clone(),
             use_flash_attn: cfg.use_flash_attn,
             span,
             span_rot,
@@ -357,19 +357,25 @@ struct Block {
 }
 
 impl Block {
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let residual = x;
         let x = self.rms_1.forward(x)?;
-        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
         let residual = &x;
         let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
         Ok(x)
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "block");
-        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
+        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
         let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
         let rms_2 = RmsNorm::load(
@@ -396,11 +402,11 @@ pub struct Llama {
 }
 
 impl Llama {
-    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
-            x = block.forward(&x, index_pos, block_idx)?;
+            x = block.forward(&x, index_pos, block_idx, cache)?;
         }
         let x = self.ln_f.forward(&x)?;
         let x = x.i((.., seq_len - 1, ..))?;
@@ -408,12 +414,12 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
-            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
+            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
            .collect();
 
         Ok(Self {
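
Not part of the patch above: a minimal sketch of the calling convention this change introduces, assuming the candle and candle-transformers crates as dependencies and the setup from the example (config, device, tokenized prompt). The caller now owns the Cache built by model::Cache::new and passes it as &mut on every forward call, instead of the model keeping it behind Arc<Mutex<...>>. The helper name next_token_logits is hypothetical.

use candle::{Device, Tensor};
use candle_transformers::models::llama::{Cache, Llama};

// Hypothetical helper illustrating the explicit-cache API after this change.
fn next_token_logits(
    llama: &Llama,
    cache: &mut Cache, // caller-owned KV cache, mutated in place at each step
    ctxt: &[u32],
    index_pos: usize,
    device: &Device,
) -> candle::Result<Tensor> {
    let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
    // The model itself stays immutable; all per-sequence mutable state lives in `cache`.
    let logits = llama.forward(&input, index_pos, cache)?;
    logits.squeeze(0)
}

With the cache external to the model, a second independent sequence can reuse the same Llama by constructing another Cache for it.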