Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Cache the causal mask in llama.
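In short: instead of rebuilding the (t, t) causal attention mask on every forward pass, the mask is now built once per sequence length t and memoized in a HashMap shared behind an Arc<Mutex<..>>. A minimal standalone sketch of that pattern using std types only (MaskCache and the Vec<u32> payload are illustrative stand-ins; the real code stores candle Tensors, as the diff below shows):

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Illustrative stand-in for the commit's Cache: masks are keyed by
// sequence length, built on first use, and cloned out on later hits.
#[derive(Clone)]
struct MaskCache {
    masks: Arc<Mutex<HashMap<usize, Vec<u32>>>>,
}

impl MaskCache {
    fn new() -> Self {
        Self { masks: Arc::new(Mutex::new(HashMap::new())) }
    }

    // Returns the t*t row-major causal mask: entry (i, j) is 1 iff j > i,
    // i.e. 1 marks the future positions attention must not look at.
    fn mask(&self, t: usize) -> Vec<u32> {
        self.masks
            .lock()
            .unwrap()
            .entry(t)
            .or_insert_with(|| {
                (0..t)
                    .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
                    .collect()
            })
            .clone()
    }
}

fn main() {
    let cache = MaskCache::new();
    // First call builds the mask; prints [0, 1, 1, 0, 0, 1, 0, 0, 0].
    println!("{:?}", cache.mask(3));
    // Clone only clones the Arc, so every holder (one per attention
    // layer in the diff) shares the same underlying map; this call
    // through a clone is a cache hit.
    let shared = cache.clone();
    assert_eq!(shared.mask(3), cache.mask(3));
}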
@@ -15,6 +15,8 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 
 use candle::{DType, Device, Tensor};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
 
 mod var_store;
 use var_store::VarBuilder;
@@ -231,15 +233,50 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
     Ok(m)
 }
 
+#[derive(Clone)]
+struct Cache {
+    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    device: Device,
+}
+
+impl Cache {
+    fn new(device: &Device) -> Self {
+        Self {
+            masks: Arc::new(Mutex::new(HashMap::new())),
+            device: device.clone(),
+        }
+    }
+
+    fn mask(&self, t: usize) -> Result<Tensor> {
+        let mut masks = self.masks.lock().unwrap();
+        if let Some(mask) = masks.get(&t) {
+            Ok(mask.clone())
+        } else {
+            // TODO: If we support bool or u8 tensors, this would be better.
+            let mask: Vec<_> = (0..t)
+                .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+                .collect();
+            // Once lower_triangle is available, use the following:
+            //let mask = Tensor::new(1u32, &device)?
+            //    .broadcast_as(&[t, t])?
+            //    .lower_triangle()?
+            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+            masks.insert(t, mask.clone());
+            Ok(mask)
+        }
+    }
+}
+
 struct CausalSelfAttention {
     c_attn: Linear,
     c_proj: Linear,
     n_head: usize,
     n_embd: usize,
+    cache: Cache,
 }
 
 impl CausalSelfAttention {
-    fn new(vb: VarBuilder, n_head: usize, n_embd: usize) -> Result<Self> {
+    fn new(vb: VarBuilder, n_head: usize, n_embd: usize, cache: &Cache) -> Result<Self> {
         let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
         let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
         Ok(Self {
@@ -247,6 +284,7 @@ impl CausalSelfAttention {
             c_proj,
             n_head,
             n_embd,
+            cache: cache.clone(),
         })
     }
 
@@ -292,16 +330,7 @@ impl CausalSelfAttention {
         let k = self.apply_rotary_emb(&k, freqs_cis)?;
         let k_shape = k.shape();
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
-        let device = x.device();
-        // TODO: If we support bool or u8 tensors, this would be better.
-        let mask: Vec<_> = (0..t)
-            .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
-            .collect();
-        // Once lower_triangle is available, use the following:
-        //let mask = Tensor::new(1u32, &device)?
-        //    .broadcast_as(&[t, t])?
-        //    .lower_triangle()?
-        let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?;
+        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
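With the cache in place, the attention path reduces to fetching the memoized mask, broadcasting it to the attention shape, and applying masked_fill before the softmax. A hypothetical scalar illustration of what that masking does to a single attention row (candle's masked_fill and softmax apply the same idea tensor-wide; the helper below is not from the diff):

// Hypothetical per-row illustration of masked_fill + softmax.
fn masked_softmax(scores: &[f32], mask: &[u32]) -> Vec<f32> {
    // Where the mask is 1 (a future position), replace the score
    // with negative infinity, exactly as masked_fill does.
    let masked: Vec<f32> = scores
        .iter()
        .zip(mask)
        .map(|(&s, &m)| if m == 1 { f32::NEG_INFINITY } else { s })
        .collect();
    // Numerically stable softmax; exp(-inf) = 0 kills masked entries.
    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = masked.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // Row 0 of a t = 3 causal mask: positions 1 and 2 are in the future.
    let probs = masked_softmax(&[0.3, 1.2, -0.5], &[0, 1, 1]);
    // Future positions receive exactly zero attention weight.
    println!("{probs:?}"); // [1.0, 0.0, 0.0]
}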
@@ -320,9 +349,9 @@ struct Block {
 }
 
 impl Block {
-    fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
+    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
-        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd)?;
+        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd, cache)?;
         let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
         let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
         Ok(Self {
@@ -348,7 +377,7 @@ struct Llama {
 }
 
 impl Llama {
-    fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
+    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
         let wte = Embedding::new(
             &vb / "transformer" / "wte",
@@ -356,7 +385,7 @@ impl Llama {
             config.n_embd,
         )?;
         let blocks = (0..config.n_layer)
-            .map(|i| Block::new(&vb / "transformer" / "h" / i, config))
+            .map(|i| Block::new(&vb / "transformer" / "h" / i, cache, config))
            .collect::<Result<Vec<_>>>()?;
         let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
         Ok(Self {
@@ -453,13 +482,15 @@ fn main() -> Result<()> {
 
     println!("building the model");
     let config = Config::config_7b();
-    let llama = Llama::new(vb, &config)?;
+    let cache = Cache::new(&device);
+    let llama = Llama::new(vb, &cache, &config)?;
 
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut new_tokens = vec![];
     let mut rng = thread_rng();
+    let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
         let input = Tensor::new(ctxt, &device)?;
@@ -477,8 +508,11 @@ fn main() -> Result<()> {
             tokenizer.decode(vec![next_token], true).map_err(E::msg)?
         );
     }
+    let dt = start_gen.elapsed();
     println!(
-        "----\n{}\n----",
+        "{} tokens generated ({} token/s)\n----\n{}\n----",
+        args.sample_len,
+        args.sample_len as f64 / dt.as_secs_f64(),
         tokenizer.decode(new_tokens, true).map_err(E::msg)?
     );
     Ok(())
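Lastly, main now times the generation loop and reports throughput. A self-contained sketch of the same measurement pattern, with a sleep standing in for a decoding step and made-up numbers:

use std::time::Instant;

fn main() {
    let sample_len = 100usize;
    let start_gen = Instant::now();
    for _ in 0..sample_len {
        // Stand-in for one decoding step of the inference loop.
        std::thread::sleep(std::time::Duration::from_millis(1));
    }
    let dt = start_gen.elapsed();
    // Tokens per second = token count / elapsed wall-clock seconds.
    println!(
        "{} tokens generated ({:.2} token/s)",
        sample_len,
        sample_len as f64 / dt.as_secs_f64()
    );
}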