Mirror of https://github.com/huggingface/candle.git
Cache the causal mask in llama.
@@ -15,6 +15,8 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 
 use candle::{DType, Device, Tensor};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
 
 mod var_store;
 use var_store::VarBuilder;
@@ -231,15 +233,50 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
     Ok(m)
 }
 
+#[derive(Clone)]
+struct Cache {
+    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    device: Device,
+}
+
+impl Cache {
+    fn new(device: &Device) -> Self {
+        Self {
+            masks: Arc::new(Mutex::new(HashMap::new())),
+            device: device.clone(),
+        }
+    }
+
+    fn mask(&self, t: usize) -> Result<Tensor> {
+        let mut masks = self.masks.lock().unwrap();
+        if let Some(mask) = masks.get(&t) {
+            Ok(mask.clone())
+        } else {
+            // TODO: If we support bool or u8 tensors, this would be better.
+            let mask: Vec<_> = (0..t)
+                .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+                .collect();
+            // Once lower_triangle is available, use the following:
+            //let mask = Tensor::new(1u32, &device)?
+            //    .broadcast_as(&[t, t])?
+            //    .lower_triangle()?
+            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+            masks.insert(t, mask.clone());
+            Ok(mask)
+        }
+    }
+}
+
 struct CausalSelfAttention {
     c_attn: Linear,
     c_proj: Linear,
     n_head: usize,
     n_embd: usize,
+    cache: Cache,
 }
 
 impl CausalSelfAttention {
-    fn new(vb: VarBuilder, n_head: usize, n_embd: usize) -> Result<Self> {
+    fn new(vb: VarBuilder, n_head: usize, n_embd: usize, cache: &Cache) -> Result<Self> {
         let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
         let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
         Ok(Self {
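For readers skimming the diff: the Cache above is a memoized map from sequence length t to its t x t causal mask, kept behind an Arc<Mutex<...>> so that clones of the cache share the same entries. Below is a minimal, self-contained sketch of that pattern with a plain Vec<u32> standing in for candle's Tensor; the MaskCache name and the assertions are illustrative only and not part of the commit.

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Illustrative stand-in for the cached causal mask: a 1 marks a position
// that must be hidden, i.e. j > i.
#[derive(Clone)]
struct MaskCache {
    masks: Arc<Mutex<HashMap<usize, Arc<Vec<u32>>>>>,
}

impl MaskCache {
    fn new() -> Self {
        Self { masks: Arc::new(Mutex::new(HashMap::new())) }
    }

    // Returns a t x t mask (flattened row-major), built once per t and
    // then reused from the HashMap.
    fn mask(&self, t: usize) -> Arc<Vec<u32>> {
        let mut masks = self.masks.lock().unwrap();
        masks
            .entry(t)
            .or_insert_with(|| {
                Arc::new(
                    (0..t)
                        .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
                        .collect(),
                )
            })
            .clone()
    }
}

fn main() {
    let cache = MaskCache::new();
    let m3 = cache.mask(3);
    assert_eq!(&*m3, &[0u32, 1, 1, 0, 0, 1, 0, 0, 0]);
    // A clone shares the same Arc'd HashMap, so the second lookup reuses
    // the stored mask instead of rebuilding it.
    let cache2 = cache.clone();
    assert!(Arc::ptr_eq(&cache2.mask(3), &m3));
}

Because every attention layer holds a clone of the same cache, the mask for a given sequence length is computed at most once per process.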
@@ -247,6 +284,7 @@ impl CausalSelfAttention {
             c_attn,
             c_proj,
             n_head,
             n_embd,
+            cache: cache.clone(),
         })
     }
@@ -292,16 +330,7 @@ impl CausalSelfAttention {
         let k = self.apply_rotary_emb(&k, freqs_cis)?;
         let k_shape = k.shape();
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
-        let device = x.device();
-        // TODO: If we support bool or u8 tensors, this would be better.
-        let mask: Vec<_> = (0..t)
-            .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
-            .collect();
-        // Once lower_triangle is available, use the following:
-        //let mask = Tensor::new(1u32, &device)?
-        //    .broadcast_as(&[t, t])?
-        //    .lower_triangle()?
-        let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?;
+        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
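The net effect of self.cache.mask(t) followed by masked_fill is unchanged from the inlined version it replaces: positions with j > i are pushed to negative infinity before the softmax, so each token attends only to itself and earlier positions. Here is a tiny hand-rolled illustration of those numerics for a single attention row (plain Rust, no candle; the scores are made up):

// Row i = 1 of a t = 3 causal mask: position j = 2 is in the future and gets masked.
fn main() {
    let scores = [0.5f32, 1.0, 2.0];
    let mask = [0u32, 0, 1]; // from u32::from(j > i) with i = 1

    // masked_fill: where mask == 1, replace the score with -inf.
    let filled: Vec<f32> = scores
        .iter()
        .zip(mask.iter())
        .map(|(&s, &m)| if m == 1 { f32::NEG_INFINITY } else { s })
        .collect();

    // Row-wise softmax: exp(-inf) = 0, so masked positions get zero weight.
    let max = filled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = filled.iter().map(|&v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|&e| e / sum).collect();

    assert_eq!(probs[2], 0.0);
    println!("{probs:?}"); // roughly [0.378, 0.622, 0.0]
}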
@@ -320,9 +349,9 @@ struct Block {
 }
 
 impl Block {
-    fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
+    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
-        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd)?;
+        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd, cache)?;
         let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
         let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
         Ok(Self {
@@ -348,7 +377,7 @@ struct Llama {
 }
 
 impl Llama {
-    fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
+    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
         let wte = Embedding::new(
             &vb / "transformer" / "wte",
@@ -356,7 +385,7 @@ impl Llama {
             config.n_embd,
         )?;
         let blocks = (0..config.n_layer)
-            .map(|i| Block::new(&vb / "transformer" / "h" / i, config))
+            .map(|i| Block::new(&vb / "transformer" / "h" / i, cache, config))
            .collect::<Result<Vec<_>>>()?;
         let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
         Ok(Self {
@@ -453,13 +482,15 @@ fn main() -> Result<()> {
 
     println!("building the model");
     let config = Config::config_7b();
-    let llama = Llama::new(vb, &config)?;
+    let cache = Cache::new(&device);
+    let llama = Llama::new(vb, &cache, &config)?;
 
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut new_tokens = vec![];
     let mut rng = thread_rng();
+    let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
         let input = Tensor::new(ctxt, &device)?;
@@ -477,8 +508,11 @@ fn main() -> Result<()> {
             tokenizer.decode(vec![next_token], true).map_err(E::msg)?
         );
     }
+    let dt = start_gen.elapsed();
     println!(
-        "----\n{}\n----",
+        "{} tokens generated ({} token/s)\n----\n{}\n----",
+        args.sample_len,
+        args.sample_len as f64 / dt.as_secs_f64(),
         tokenizer.decode(new_tokens, true).map_err(E::msg)?
     );
     Ok(())
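The new timing output simply divides the number of sampled tokens by the wall-clock duration of the generation loop. A trimmed, standalone sketch of that calculation (the sleep stands in for sampling one token; sample_len is a stand-in for args.sample_len):

use std::time::{Duration, Instant};

fn main() {
    let sample_len = 100usize; // stand-in for args.sample_len
    let start_gen = Instant::now();
    for _ in 0..sample_len {
        // ... sample one token per iteration ...
        std::thread::sleep(Duration::from_millis(1));
    }
    let dt = start_gen.elapsed();
    println!(
        "{} tokens generated ({} token/s)",
        sample_len,
        sample_len as f64 / dt.as_secs_f64(),
    );
}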