From 23389b1bd777b89fd56c8b0c43a5685c3bc6e9e2 Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 29 Jun 2023 22:00:57 +0100
Subject: [PATCH 1/3] Enable the KV cache after fixing the caching length and the rope bits.

---
 candle-core/examples/llama/main.rs | 35 ++++++++++++++++++------------
 1 file changed, 21 insertions(+), 14 deletions(-)

diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 5a8a15d3..18c5aafc 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -23,8 +23,8 @@ use std::sync::{Arc, Mutex};
 mod var_store;
 mod weights;
 
-const CONTEXT_SIZE: usize = 512;
-const USE_KV_CACHE: bool = false;
+const MAX_SEQ_LEN: usize = 4096;
+const USE_KV_CACHE: bool = true;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -268,8 +268,9 @@ impl CausalSelfAttention {
 
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let mut dims = x.dims().to_vec();
-        let freqs_cis = if dims[1] < CONTEXT_SIZE {
-            freqs_cis.narrow(1, CONTEXT_SIZE - dims[1], dims[1])?
+        let fcis_dims = freqs_cis.dims();
+        let freqs_cis = if dims[1] < fcis_dims[1] {
+            freqs_cis.narrow(1, 0, dims[1])?
         } else {
             freqs_cis.clone()
         };
@@ -311,18 +312,18 @@ impl CausalSelfAttention {
         if USE_KV_CACHE {
             let mut cache = self.cache.kvs.lock().unwrap();
             if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 1)?;
-                v = Tensor::cat(&[cache_v, &v], 1)?;
+                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
+                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
-                if k_seq_len > CONTEXT_SIZE {
+                if k_seq_len > MAX_SEQ_LEN {
                     k = k
-                        .narrow(1, k_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
+                        .narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
                 let v_seq_len = v.dims()[1];
-                if v_seq_len > CONTEXT_SIZE {
+                if v_seq_len > 2 * MAX_SEQ_LEN {
                     v = v
-                        .narrow(1, v_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
+                        .narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
             }
@@ -405,19 +406,18 @@ impl Llama {
 }
 
 fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
-    let seq_len = CONTEXT_SIZE;
     let n_elem = config.n_embd / config.n_head;
     let theta: Vec<_> = (0..n_elem)
         .step_by(2)
         .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
         .collect();
-    let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
+    let arange: Vec<_> = (0..MAX_SEQ_LEN).map(|c| c as f32).collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
     let arange = Tensor::new(arange.as_slice(), device)?;
     let idx_theta = arange
         .reshape((arange.elem_count(), 1))?
         .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-    let shape = [1, seq_len, n_elem / 2, 1];
+    let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
     let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
     let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
     let last_dim = idx_theta_cos.rank() - 1;
@@ -503,16 +503,23 @@ async fn main() -> Result<()> {
     let mut new_tokens = vec![];
     let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
     let start_gen = std::time::Instant::now();
+    let mut index_pos = 0;
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
         let context_size = if USE_KV_CACHE && index > 0 {
             1
         } else {
-            CONTEXT_SIZE
+            tokens.len()
        };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?;
+        let freqs_cis = if USE_KV_CACHE {
+            freqs_cis.narrow(1, index_pos, ctxt.len())?
+        } else {
+            freqs_cis.clone()
+        };
         let logits = llama.forward(&input, &freqs_cis)?;
+        index_pos += ctxt.len();
         let next_token = if let Some(temperature) = args.temperature {
             println!("Sampling with temperature {temperature:?}");
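With the cache enabled, every step after the first feeds a single new token into the model, so two pieces of bookkeeping in this patch matter: the concatenated keys/values are trimmed once they grow past MAX_SEQ_LEN, and the rotary table is sliced starting at the running position index_pos instead of always at offset 0. The standalone sketch below mirrors that bookkeeping with plain Vecs; the scalar cache entries and the tiny MAX_SEQ_LEN value are made-up stand-ins for the per-head key/value tensors, not the example's real data.

    // Sketch of the cache append/trim and position tracking from the patch above.
    const MAX_SEQ_LEN: usize = 8; // tiny value so trimming is visible

    fn main() {
        let prompt_len: usize = 5;
        let sample_len: usize = 10;
        let mut kv_cache: Vec<f32> = Vec::new(); // stands in for the cached (k, v) tensors
        let mut index_pos = 0;

        for index in 0..sample_len {
            // With the cache on, only the newest token is fed after step 0.
            let context_size = if index > 0 { 1 } else { prompt_len };
            // Slice the rope table at the running position, mirroring
            // freqs_cis.narrow(1, index_pos, ctxt.len()).
            let rope_range = index_pos..index_pos + context_size;

            // Append the new entries, then trim to MAX_SEQ_LEN, mirroring
            // Tensor::cat followed by narrow + contiguous.
            kv_cache.extend((0..context_size).map(|i| (index_pos + i) as f32));
            if kv_cache.len() > MAX_SEQ_LEN {
                let excess = kv_cache.len() - MAX_SEQ_LEN;
                kv_cache.drain(..excess);
            }

            index_pos += context_size;
            println!("step {index}: rope slice {rope_range:?}, cache len {}", kv_cache.len());
        }
    }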
From ae3f202f3bae3d78f02fb2c4cfbe82098bf5c006 Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 29 Jun 2023 22:12:15 +0100
Subject: [PATCH 2/3] Add a flag.

---
 candle-core/examples/llama/main.rs | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 18c5aafc..2ec2a9da 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -24,7 +24,6 @@ mod var_store;
 mod weights;
 
 const MAX_SEQ_LEN: usize = 4096;
-const USE_KV_CACHE: bool = true;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -219,15 +218,17 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 #[derive(Clone)]
 struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    use_kv_cache: bool,
     #[allow(clippy::type_complexity)]
     kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
     device: Device,
 }
 impl Cache {
-    fn new(config: &Config, device: &Device) -> Self {
+    fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
         Self {
             masks: Arc::new(Mutex::new(HashMap::new())),
+            use_kv_cache,
             kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
             device: device.clone(),
         }
     }
@@ -309,7 +310,7 @@ impl CausalSelfAttention {
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
         let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
 
-        if USE_KV_CACHE {
+        if self.cache.use_kv_cache {
             let mut cache = self.cache.kvs.lock().unwrap();
             if let Some((cache_k, cache_v)) = &cache[block_idx] {
                 k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
@@ -446,6 +447,10 @@ struct Args {
     /// The length of the sample to generate (in tokens).
     #[arg(long, default_value_t = 100)]
     sample_len: usize,
+
+    /// Enable the key-value cache.
+    #[arg(long, default_value_t = true)]
+    use_kv_cache: bool,
 }
 
 #[tokio::main]
@@ -459,7 +464,7 @@ async fn main() -> Result<()> {
         Device::new_cuda(0)?
     };
     let config = Config::config_7b();
-    let cache = Cache::new(&config, &device);
+    let cache = Cache::new(args.use_kv_cache, &config, &device);
     let start = std::time::Instant::now();
     let (llama, tokenizer_filename) = match args.npy {
         Some(npy) => {
@@ -506,14 +511,14 @@ async fn main() -> Result<()> {
     let mut index_pos = 0;
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
-        let context_size = if USE_KV_CACHE && index > 0 {
+        let context_size = if cache.use_kv_cache && index > 0 {
             1
         } else {
             tokens.len()
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?;
-        let freqs_cis = if USE_KV_CACHE {
+        let freqs_cis = if cache.use_kv_cache {
             freqs_cis.narrow(1, index_pos, ctxt.len())?
         } else {
             freqs_cis.clone()
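Rather than a compile-time constant, the flag now rides on the Cache handle that main builds once and every CausalSelfAttention block clones, so the attention code (self.cache.use_kv_cache) and the sampling loop (cache.use_kv_cache) observe the same setting. Below is a minimal sketch of that shared-handle pattern with simplified types: plain Vec<f32> pairs stand in for the cached key/value tensors.

    // Sketch of a Clone-able cache handle whose per-layer store sits behind Arc<Mutex<...>>.
    use std::sync::{Arc, Mutex};

    #[derive(Clone)]
    struct Cache {
        use_kv_cache: bool,
        // One optional (k, v) slot per layer; the real code stores candle Tensors here.
        kvs: Arc<Mutex<Vec<Option<(Vec<f32>, Vec<f32>)>>>>,
    }

    impl Cache {
        fn new(use_kv_cache: bool, n_layer: usize) -> Self {
            Self {
                use_kv_cache,
                kvs: Arc::new(Mutex::new(vec![None; n_layer])),
            }
        }
    }

    fn main() {
        let cache = Cache::new(true, 2);
        let per_block = cache.clone(); // each attention block keeps its own clone
        if per_block.use_kv_cache {
            // A block writes its slot; the update is visible through the original handle.
            per_block.kvs.lock().unwrap()[0] = Some((vec![1.0], vec![2.0]));
        }
        println!("layer 0 cached: {}", cache.kvs.lock().unwrap()[0].is_some());
    }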
From f6152e74b63f078bdbdc66ebad904ca2a4e88a53 Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 29 Jun 2023 22:16:40 +0100
Subject: [PATCH 3/3] Tweak the kv-cache flag.

---
 candle-core/examples/llama/main.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 2ec2a9da..fac1e14f 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -448,9 +448,9 @@ struct Args {
     #[arg(long, default_value_t = 100)]
     sample_len: usize,
 
-    /// Enable the key-value cache.
-    #[arg(long, default_value_t = true)]
-    use_kv_cache: bool,
+    /// Disable the key-value cache.
+    #[arg(long)]
+    no_kv_cache: bool,
 }
 
 #[tokio::main]
@@ -464,7 +464,7 @@ async fn main() -> Result<()> {
         Device::new_cuda(0)?
     };
     let config = Config::config_7b();
-    let cache = Cache::new(args.use_kv_cache, &config, &device);
+    let cache = Cache::new(!args.no_kv_cache, &config, &device);
     let start = std::time::Instant::now();
     let (llama, tokenizer_filename) = match args.npy {
         Some(npy) => {
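With clap's derive API (which the example uses for Args), a plain bool field behind #[arg(long)] defaults to false and flips to true only when the flag is passed, whereas the earlier default_value_t = true version left no way to switch the cache off from the command line. Inverting the option keeps the cache on by default while making it opt-out, hence Cache::new(!args.no_kv_cache, ...). A minimal sketch of the inverted flag, assuming the usual clap::Parser setup (the Args::parse() call is not shown in the excerpt above):

    // Sketch of an opt-out boolean flag with clap's derive API.
    use clap::Parser;

    #[derive(Parser)]
    struct Args {
        /// Disable the key-value cache.
        #[arg(long)]
        no_kv_cache: bool,
    }

    fn main() {
        let args = Args::parse();
        // The cache stays enabled unless --no-kv-cache is passed explicitly.
        let use_kv_cache = !args.no_kv_cache;
        println!("kv cache enabled: {use_kv_cache}");
    }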