Cache the causal mask in llama.

laurent
2023-06-27 12:21:08 +01:00
parent 527a71fdad
commit 318503cd38
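The change below moves construction of the causal mask out of the attention forward pass and into a small Cache shared by every layer: an Arc<Mutex<HashMap<usize, Tensor>>> keyed by the sequence length t, so each distinct mask is built once and cloned on later calls. As a rough standalone sketch of the pattern being cached (the causal_mask helper and the printed example are illustrative, not part of the commit), the flat_map expression yields a strictly upper-triangular matrix of 1s, marking the future positions that masked_fill later sets to f32::NEG_INFINITY:

    // Illustrative sketch only: the pattern cached per sequence length t.
    // A 1 marks a future position (column j > row i) that masked_fill will
    // replace with f32::NEG_INFINITY in the attention scores.
    fn causal_mask(t: usize) -> Vec<u32> {
        (0..t)
            .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
            .collect()
    }

    fn main() {
        let t = 4;
        for row in causal_mask(t).chunks(t) {
            println!("{row:?}");
        }
        // Prints:
        // [0, 1, 1, 1]
        // [0, 0, 1, 1]
        // [0, 0, 0, 1]
        // [0, 0, 0, 0]
    }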


@@ -15,6 +15,8 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 use candle::{DType, Device, Tensor};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
 mod var_store;
 use var_store::VarBuilder;
@@ -231,15 +233,50 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
     Ok(m)
 }
+#[derive(Clone)]
+struct Cache {
+    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    device: Device,
+}
+impl Cache {
+    fn new(device: &Device) -> Self {
+        Self {
+            masks: Arc::new(Mutex::new(HashMap::new())),
+            device: device.clone(),
+        }
+    }
+    fn mask(&self, t: usize) -> Result<Tensor> {
+        let mut masks = self.masks.lock().unwrap();
+        if let Some(mask) = masks.get(&t) {
+            Ok(mask.clone())
+        } else {
+            // TODO: If we support bool or u8 tensors, this would be better.
+            let mask: Vec<_> = (0..t)
+                .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+                .collect();
+            // Once lower_triangle is available, use the following:
+            //let mask = Tensor::new(1u32, &device)?
+            //    .broadcast_as(&[t, t])?
+            //    .lower_triangle()?
+            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+            masks.insert(t, mask.clone());
+            Ok(mask)
+        }
+    }
+}
 struct CausalSelfAttention {
     c_attn: Linear,
     c_proj: Linear,
     n_head: usize,
     n_embd: usize,
+    cache: Cache,
 }
 impl CausalSelfAttention {
-    fn new(vb: VarBuilder, n_head: usize, n_embd: usize) -> Result<Self> {
+    fn new(vb: VarBuilder, n_head: usize, n_embd: usize, cache: &Cache) -> Result<Self> {
         let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
         let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
         Ok(Self {
@@ -247,6 +284,7 @@ impl CausalSelfAttention {
             c_proj,
             n_head,
             n_embd,
+            cache: cache.clone(),
         })
     }
@@ -292,16 +330,7 @@ impl CausalSelfAttention {
         let k = self.apply_rotary_emb(&k, freqs_cis)?;
         let k_shape = k.shape();
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
-        let device = x.device();
-        // TODO: If we support bool or u8 tensors, this would be better.
-        let mask: Vec<_> = (0..t)
-            .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
-            .collect();
-        // Once lower_triangle is available, use the following:
-        //let mask = Tensor::new(1u32, &device)?
-        //    .broadcast_as(&[t, t])?
-        //    .lower_triangle()?
-        let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?;
+        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
@@ -320,9 +349,9 @@ struct Block {
 }
 impl Block {
-    fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
+    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
-        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd)?;
+        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd, cache)?;
         let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
         let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
         Ok(Self {
@@ -348,7 +377,7 @@ struct Llama {
 }
 impl Llama {
-    fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
+    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
         let wte = Embedding::new(
             &vb / "transformer" / "wte",
@@ -356,7 +385,7 @@ impl Llama {
             config.n_embd,
         )?;
         let blocks = (0..config.n_layer)
-            .map(|i| Block::new(&vb / "transformer" / "h" / i, config))
+            .map(|i| Block::new(&vb / "transformer" / "h" / i, cache, config))
            .collect::<Result<Vec<_>>>()?;
         let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
         Ok(Self {
@@ -453,13 +482,15 @@ fn main() -> Result<()> {
     println!("building the model");
     let config = Config::config_7b();
-    let llama = Llama::new(vb, &config)?;
+    let cache = Cache::new(&device);
+    let llama = Llama::new(vb, &cache, &config)?;
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut new_tokens = vec![];
     let mut rng = thread_rng();
+    let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
         let input = Tensor::new(ctxt, &device)?;
@@ -477,8 +508,11 @@
             tokenizer.decode(vec![next_token], true).map_err(E::msg)?
         );
     }
+    let dt = start_gen.elapsed();
     println!(
-        "----\n{}\n----",
+        "{} tokens generated ({} token/s)\n----\n{}\n----",
+        args.sample_len,
+        args.sample_len as f64 / dt.as_secs_f64(),
         tokenizer.decode(new_tokens, true).map_err(E::msg)?
     );
     Ok(())