diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 18c5aafc..2ec2a9da 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -24,7 +24,6 @@ mod var_store;
 mod weights;
 
 const MAX_SEQ_LEN: usize = 4096;
-const USE_KV_CACHE: bool = true;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -219,15 +218,17 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
 #[derive(Clone)]
 struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    use_kv_cache: bool,
     #[allow(clippy::type_complexity)]
     kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
     device: Device,
 }
 
 impl Cache {
-    fn new(config: &Config, device: &Device) -> Self {
+    fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
         Self {
             masks: Arc::new(Mutex::new(HashMap::new())),
+            use_kv_cache,
             kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
             device: device.clone(),
         }
@@ -309,7 +310,7 @@ impl CausalSelfAttention {
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
         let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
 
-        if USE_KV_CACHE {
+        if self.cache.use_kv_cache {
             let mut cache = self.cache.kvs.lock().unwrap();
             if let Some((cache_k, cache_v)) = &cache[block_idx] {
                 k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
@@ -446,6 +447,10 @@ struct Args {
     /// The length of the sample to generate (in tokens).
     #[arg(long, default_value_t = 100)]
     sample_len: usize,
+
+    /// Enable the key-value cache.
+    #[arg(long, default_value_t = true)]
+    use_kv_cache: bool,
 }
 
 #[tokio::main]
@@ -459,7 +464,7 @@ async fn main() -> Result<()> {
         Device::new_cuda(0)?
     };
     let config = Config::config_7b();
-    let cache = Cache::new(&config, &device);
+    let cache = Cache::new(args.use_kv_cache, &config, &device);
     let start = std::time::Instant::now();
     let (llama, tokenizer_filename) = match args.npy {
         Some(npy) => {
@@ -506,14 +511,14 @@ async fn main() -> Result<()> {
     let mut index_pos = 0;
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
-        let context_size = if USE_KV_CACHE && index > 0 {
+        let context_size = if cache.use_kv_cache && index > 0 {
             1
         } else {
             tokens.len()
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?;
-        let freqs_cis = if USE_KV_CACHE {
+        let freqs_cis = if cache.use_kv_cache {
             freqs_cis.narrow(1, index_pos, ctxt.len())?
         } else {
             freqs_cis.clone()
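
For readers who want to see the cache mechanics in isolation, below is a minimal, self-contained sketch of the bookkeeping this patch threads through Cache: one Option<(k, v)> slot per layer behind an Arc<Mutex<...>>, so that every attention block (each holding a clone of the Cache) can append through a shared lock while forward passes keep taking &self. Here Vec::extend stands in for Tensor::cat along the sequence axis; all types and names are hypothetical stand-ins, since the real example operates on candle Tensors on a Device.

use std::sync::{Arc, Mutex};

// Stand-in for a [seq_len, head_dim] tensor: one Vec<f32> row per position.
type Rows = Vec<Vec<f32>>;

// Mirror of the patch's Cache shape: one Option<(k, v)> slot per layer,
// shared between blocks (and clones of the model) through Arc<Mutex<..>>.
#[derive(Clone)]
struct Cache {
    use_kv_cache: bool,
    kvs: Arc<Mutex<Vec<Option<(Rows, Rows)>>>>,
}

impl Cache {
    fn new(use_kv_cache: bool, n_layer: usize) -> Self {
        Self {
            use_kv_cache,
            kvs: Arc::new(Mutex::new(vec![None; n_layer])),
        }
    }

    // Append this step's k/v rows for one layer and return everything seen
    // so far; Vec::extend plays the role of Tensor::cat on the sequence axis.
    fn append(&self, block_idx: usize, k_new: Rows, v_new: Rows) -> (Rows, Rows) {
        if !self.use_kv_cache {
            return (k_new, v_new);
        }
        let mut kvs = self.kvs.lock().unwrap();
        let (mut k, mut v) = kvs[block_idx].take().unwrap_or_default();
        k.extend(k_new);
        v.extend(v_new);
        kvs[block_idx] = Some((k.clone(), v.clone()));
        (k, v)
    }
}

fn main() {
    let cache = Cache::new(true, 2);
    // Step 0: the whole 3-token prompt flows through layer 0.
    let (k, _) = cache.append(0, vec![vec![0.0]; 3], vec![vec![0.0]; 3]);
    assert_eq!(k.len(), 3);
    // Later steps: with the cache on, context_size is 1, so only the newly
    // sampled token is fed, yet attention still sees all four positions.
    let (k, _) = cache.append(0, vec![vec![0.0]; 1], vec![vec![0.0]; 1]);
    assert_eq!(k.len(), 4);
    println!("cached positions after two steps: {}", k.len());
}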
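
One usage note on the new flag, assuming clap 4 (which the #[arg(...)] attribute syntax suggests): the derive API gives a plain bool field an implicit ArgAction::SetTrue, so combined with default_value_t = true the flag can never actually be switched off from the command line; passing --use-kv-cache just sets it to true again. If a real toggle is wanted, one option (a sketch, not what this patch does) is to make the boolean take an explicit value:

use clap::{ArgAction, Parser};

#[derive(Parser, Debug)]
struct Args {
    /// Enable the key-value cache; disable with `--use-kv-cache=false`.
    // ArgAction::Set makes the flag take an explicit boolean value instead
    // of the implicit SetTrue action, so the `true` default stays overridable.
    #[arg(long, action = ArgAction::Set, default_value_t = true)]
    use_kv_cache: bool,
}

fn main() {
    let args = Args::parse();
    println!("use_kv_cache = {}", args.use_kv_cache);
}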