Explicit caching in llama2.c.

This commit is contained in:
laurent
2024-02-22 10:22:03 +01:00
parent c753f72c85
commit 544018b6d0
4 changed files with 99 additions and 75 deletions

View File

@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
use std::io::Write;
use tokenizers::Tokenizer;
use model::{Config, Llama};
use model::{Cache, Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;
@ -160,10 +160,10 @@ enum Model {
}
impl Model {
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
match self {
Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
}
}
}
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?;
let tokens = match &args.pretokenized_dir {
None => {
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
let logits = model.forward(&inp, 0, &mut cache)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?);
}
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf {
let (model, config, mut cache) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
&device,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
let model = Model::QLlama(QLlama::load(vb, config.clone())?);
(model, config, cache)
} else if is_safetensors {
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config, cache)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config, cache)
};
println!("starting the inference loop");
@ -338,7 +338,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?;
let logits = model.forward(&input, index_pos, &mut cache)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits

View File

@ -8,6 +8,7 @@ fn valid_loss(
model: &Llama,
args: &crate::TrainingCmd,
device: &Device,
cache: &mut Cache,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@ -15,7 +16,7 @@ fn valid_loss(
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
let logits = model.forward(&inp, 0, cache)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1;
@ -37,8 +38,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?;
let params = candle_nn::ParamsAdamW {
lr: args.learning_rate,
..Default::default()
@ -46,14 +47,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0)?;
let logits = model.forward(&inp, 0, &mut cache)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
opt.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss.
let loss = valid_loss(&dataset, &model, args, &device)?;
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
println!("{batch_index} {loss}");
}
if batch_index > 0 && batch_index % 1000 == 0 {

View File

@ -2,7 +2,6 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[derive(Debug, Clone)]
pub struct Config {
@ -70,12 +69,11 @@ impl Config {
}
}
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
masks: HashMap<usize, Tensor>,
pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
pub kvs: Vec<Option<(Tensor, Tensor)>>,
pub cos: Tensor,
pub sin: Tensor,
device: Device,
@ -105,25 +103,24 @@ impl Cache {
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())),
masks: HashMap::new(),
use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
kvs: vec![None; cfg.n_layers],
cos,
sin,
device: vb.device().clone(),
})
}
pub fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
pub fn mask(&mut self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());
self.masks.insert(t, mask.clone());
Ok(mask)
}
}
@ -133,6 +130,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
#[derive(Debug, Clone)]
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@ -141,14 +139,13 @@ struct CausalSelfAttention {
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -162,7 +159,13 @@ impl CausalSelfAttention {
Ok(rope)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
@ -172,16 +175,15 @@ impl CausalSelfAttention {
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
if cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
}
cache[block_idx] = Some((k.clone(), v.clone()))
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
@ -192,7 +194,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
@ -216,7 +218,7 @@ impl CausalSelfAttention {
}
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -232,7 +234,6 @@ impl CausalSelfAttention {
n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
})
}
}
@ -244,6 +245,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m)
}
#[derive(Debug, Clone)]
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
@ -274,6 +276,7 @@ impl Mlp {
}
}
#[derive(Debug, Clone)]
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
@ -291,17 +294,23 @@ impl Block {
}
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
@ -315,6 +324,7 @@ impl Block {
}
}
#[derive(Debug, Clone)]
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
@ -324,23 +334,23 @@ pub struct Llama {
}
impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
x = block.forward(&x, index_pos, block_idx, cache)?;
}
let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap())
.collect();
Ok(Self {
wte,

View File

@ -7,6 +7,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
#[derive(Debug, Clone)]
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@ -15,14 +16,13 @@ struct CausalSelfAttention {
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -36,7 +36,13 @@ impl CausalSelfAttention {
Ok(rope)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
@ -46,16 +52,15 @@ impl CausalSelfAttention {
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
if cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
}
cache[block_idx] = Some((k.clone(), v.clone()))
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
@ -66,7 +71,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
@ -90,7 +95,7 @@ impl CausalSelfAttention {
}
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -106,7 +111,6 @@ impl CausalSelfAttention {
n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
})
}
}
@ -118,6 +122,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m)
}
#[derive(Debug, Clone)]
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
@ -148,6 +153,7 @@ impl Mlp {
}
}
#[derive(Debug, Clone)]
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
@ -165,17 +171,23 @@ impl Block {
}
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
@ -189,6 +201,7 @@ impl Block {
}
}
#[derive(Debug, Clone)]
pub struct QLlama {
wte: Embedding,
blocks: Vec<Block>,
@ -198,23 +211,23 @@ pub struct QLlama {
}
impl QLlama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
x = block.forward(&x, index_pos, block_idx, cache)?;
}
let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap())
.collect();
Ok(Self {
wte,