diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 32e4b746..baf0cdb8 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -15,6 +15,8 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 
 use candle::{DType, Device, Tensor};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
 
 mod var_store;
 use var_store::VarBuilder;
@@ -231,15 +233,50 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     Ok(m)
 }
 
+#[derive(Clone)]
+struct Cache {
+    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    device: Device,
+}
+
+impl Cache {
+    fn new(device: &Device) -> Self {
+        Self {
+            masks: Arc::new(Mutex::new(HashMap::new())),
+            device: device.clone(),
+        }
+    }
+
+    fn mask(&self, t: usize) -> Result<Tensor> {
+        let mut masks = self.masks.lock().unwrap();
+        if let Some(mask) = masks.get(&t) {
+            Ok(mask.clone())
+        } else {
+            // TODO: If we support bool or u8 tensors, this would be better.
+            let mask: Vec<_> = (0..t)
+                .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+                .collect();
+            // Once lower_triangle is available, use the following:
+            //let mask = Tensor::new(1u32, &device)?
+            //    .broadcast_as(&[t, t])?
+            //    .lower_triangle()?
+            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+            masks.insert(t, mask.clone());
+            Ok(mask)
+        }
+    }
+}
+
 struct CausalSelfAttention {
     c_attn: Linear,
     c_proj: Linear,
     n_head: usize,
     n_embd: usize,
+    cache: Cache,
 }
 
 impl CausalSelfAttention {
-    fn new(vb: VarBuilder, n_head: usize, n_embd: usize) -> Result<Self> {
+    fn new(vb: VarBuilder, n_head: usize, n_embd: usize, cache: &Cache) -> Result<Self> {
         let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
         let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
         Ok(Self {
@@ -247,6 +284,7 @@ impl CausalSelfAttention {
             c_proj,
             n_head,
             n_embd,
+            cache: cache.clone(),
         })
     }
 
@@ -292,16 +330,7 @@ impl CausalSelfAttention {
         let k = self.apply_rotary_emb(&k, freqs_cis)?;
         let k_shape = k.shape();
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
-        let device = x.device();
-        // TODO: If we support bool or u8 tensors, this would be better.
-        let mask: Vec<_> = (0..t)
-            .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
-            .collect();
-        // Once lower_triangle is available, use the following:
-        //let mask = Tensor::new(1u32, &device)?
-        //    .broadcast_as(&[t, t])?
-        //    .lower_triangle()?
-        let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?;
+        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
@@ -320,9 +349,9 @@ struct Block {
 }
 
 impl Block {
-    fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
+    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
-        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd)?;
+        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd, cache)?;
         let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
         let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
         Ok(Self {
@@ -348,7 +377,7 @@ struct Llama {
 }
 
 impl Llama {
-    fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
+    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
         let wte = Embedding::new(
             &vb / "transformer" / "wte",
@@ -356,7 +385,7 @@ impl Llama {
             config.n_embd,
         )?;
         let blocks = (0..config.n_layer)
-            .map(|i| Block::new(&vb / "transformer" / "h" / i, config))
+            .map(|i| Block::new(&vb / "transformer" / "h" / i, cache, config))
             .collect::<Result<Vec<_>>>()?;
         let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
         Ok(Self {
@@ -453,13 +482,15 @@ fn main() -> Result<()> {
 
     println!("building the model");
     let config = Config::config_7b();
-    let llama = Llama::new(vb, &config)?;
+    let cache = Cache::new(&device);
+    let llama = Llama::new(vb, &cache, &config)?;
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
 
     println!("starting the inference loop");
     let mut new_tokens = vec![];
     let mut rng = thread_rng();
+    let start_gen = std::time::Instant::now();
    for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
         let input = Tensor::new(ctxt, &device)?;
@@ -477,8 +508,11 @@ fn main() -> Result<()> {
             tokenizer.decode(vec![next_token], true).map_err(E::msg)?
         );
     }
+    let dt = start_gen.elapsed();
     println!(
-        "----\n{}\n----",
+        "{} tokens generated ({} token/s)\n----\n{}\n----",
+        args.sample_len,
+        args.sample_len as f64 / dt.as_secs_f64(),
         tokenizer.decode(new_tokens, true).map_err(E::msg)?
     );
     Ok(())