Explicit caching in llama2.c.

This commit is contained in:
laurent
2024-02-22 10:22:03 +01:00
parent c753f72c85
commit 544018b6d0
4 changed files with 99 additions and 75 deletions

View File

@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
use std::io::Write; use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use model::{Config, Llama}; use model::{Cache, Config, Llama};
use qmodel::QLlama; use qmodel::QLlama;
use weights::TransformerWeights; use weights::TransformerWeights;
@ -160,10 +160,10 @@ enum Model {
} }
impl Model { impl Model {
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> { fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
match self { match self {
Self::Llama(l) => Ok(l.forward(xs, pos)?), Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
Self::QLlama(l) => Ok(l.forward(xs, pos)?), Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
} }
} }
} }
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?; let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(false, &config, vb.pp("rot"))?; let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?; let model = Llama::load(vb, config)?;
let tokens = match &args.pretokenized_dir { let tokens = match &args.pretokenized_dir {
None => { None => {
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter { for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?; let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?; let logits = model.forward(&inp, 0, &mut cache)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?); println!("{}", loss.to_vec0::<f32>()?);
} }
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let is_safetensors = config_path let is_safetensors = config_path
.extension() .extension()
.map_or(false, |v| v == "safetensors"); .map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf { let (model, config, mut cache) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?; let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")? .get_no_shape("model.embed_tokens.weight")?
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
&device, &device,
); );
let cache = model::Cache::new(true, &config, fake_vb)?; let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); let model = Model::QLlama(QLlama::load(vb, config.clone())?);
(model, config) (model, config, cache)
} else if is_safetensors { } else if is_safetensors {
let config = Config::tiny_15m(); let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?; let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?); let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config) (model, config, cache)
} else { } else {
let mut file = std::fs::File::open(config_path)?; let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?; let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?); let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config) (model, config, cache)
}; };
println!("starting the inference loop"); println!("starting the inference loop");
@ -338,7 +338,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let context_size = if index > 0 { 1 } else { tokens.len() }; let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?; let logits = model.forward(&input, index_pos, &mut cache)?;
let logits = logits.i((0, logits.dim(1)? - 1))?; let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() { let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits logits

View File

@ -8,6 +8,7 @@ fn valid_loss(
model: &Llama, model: &Llama,
args: &crate::TrainingCmd, args: &crate::TrainingCmd,
device: &Device, device: &Device,
cache: &mut Cache,
) -> Result<f64> { ) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone()); let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@ -15,7 +16,7 @@ fn valid_loss(
let mut cnt = 0usize; let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) { for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?; let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?; let logits = model.forward(&inp, 0, cache)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64; sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1; cnt += 1;
@ -37,8 +38,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?; let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?; let model = Llama::load(vb, config)?;
let params = candle_nn::ParamsAdamW { let params = candle_nn::ParamsAdamW {
lr: args.learning_rate, lr: args.learning_rate,
..Default::default() ..Default::default()
@ -46,14 +47,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?; let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() { for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?; let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0)?; let logits = model.forward(&inp, 0, &mut cache)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
opt.backward_step(&loss)?; opt.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 { if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the // TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss. // validation loss.
let loss = valid_loss(&dataset, &model, args, &device)?; let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
println!("{batch_index} {loss}"); println!("{batch_index} {loss}");
} }
if batch_index > 0 && batch_index % 1000 == 0 { if batch_index > 0 && batch_index % 1000 == 0 {

View File

@ -2,7 +2,6 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear; use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Config { pub struct Config {
@ -70,12 +69,11 @@ impl Config {
} }
} }
#[derive(Clone)] #[derive(Debug, Clone)]
pub struct Cache { pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>, masks: HashMap<usize, Tensor>,
pub use_kv_cache: bool, pub use_kv_cache: bool,
#[allow(clippy::type_complexity)] pub kvs: Vec<Option<(Tensor, Tensor)>>,
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
pub cos: Tensor, pub cos: Tensor,
pub sin: Tensor, pub sin: Tensor,
device: Device, device: Device,
@ -105,25 +103,24 @@ impl Cache {
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?; let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?; let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
Ok(Self { Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())), masks: HashMap::new(),
use_kv_cache, use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])), kvs: vec![None; cfg.n_layers],
cos, cos,
sin, sin,
device: vb.device().clone(), device: vb.device().clone(),
}) })
} }
pub fn mask(&self, t: usize) -> Result<Tensor> { pub fn mask(&mut self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap(); if let Some(mask) = self.masks.get(&t) {
if let Some(mask) = masks.get(&t) {
Ok(mask.clone()) Ok(mask.clone())
} else { } else {
let mask: Vec<_> = (0..t) let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i))) .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect(); .collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone()); self.masks.insert(t, mask.clone());
Ok(mask) Ok(mask)
} }
} }
@ -133,6 +130,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)? xs / (xs.neg()?.exp()? + 1.0)?
} }
#[derive(Debug, Clone)]
struct CausalSelfAttention { struct CausalSelfAttention {
q_proj: Linear, q_proj: Linear,
k_proj: Linear, k_proj: Linear,
@ -141,14 +139,13 @@ struct CausalSelfAttention {
n_head: usize, n_head: usize,
n_key_value_head: usize, n_key_value_head: usize,
head_dim: usize, head_dim: usize,
cache: Cache,
} }
impl CausalSelfAttention { impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?; let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?; let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?; let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?; let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?; let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -162,7 +159,13 @@ impl CausalSelfAttention {
Ok(rope) Ok(rope)
} }
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?; let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?; let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?; let k = self.k_proj.forward(x)?;
@ -172,16 +175,15 @@ impl CausalSelfAttention {
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?; let q = self.apply_rotary_emb(&q, index_pos, cache)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?; let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
if self.cache.use_kv_cache { if cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap(); if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?; k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?; v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
} }
cache[block_idx] = Some((k.clone(), v.clone())) cache.kvs[block_idx] = Some((k.clone(), v.clone()))
} }
let k = self.repeat_kv(k)?; let k = self.repeat_kv(k)?;
@ -192,7 +194,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?.contiguous()?; let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?; let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now. // Convert to contiguous as matmul doesn't support strided vs for now.
@ -216,7 +218,7 @@ impl CausalSelfAttention {
} }
} }
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim; let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads; let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads; let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -232,7 +234,6 @@ impl CausalSelfAttention {
n_head: cfg.n_heads, n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads, n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads, head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
}) })
} }
} }
@ -244,6 +245,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m) Ok(m)
} }
#[derive(Debug, Clone)]
struct Mlp { struct Mlp {
c_fc1: Linear, c_fc1: Linear,
c_fc2: Linear, c_fc2: Linear,
@ -274,6 +276,7 @@ impl Mlp {
} }
} }
#[derive(Debug, Clone)]
struct Block { struct Block {
rms_1: RmsNorm, rms_1: RmsNorm,
attn: CausalSelfAttention, attn: CausalSelfAttention,
@ -291,17 +294,23 @@ impl Block {
} }
} }
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let residual = x; let residual = x;
let x = self.rms_1.forward(x)?; let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
let residual = &x; let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x) Ok(x)
} }
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?; let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?; let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = let post_attention_layernorm =
@ -315,6 +324,7 @@ impl Block {
} }
} }
#[derive(Debug, Clone)]
pub struct Llama { pub struct Llama {
wte: Embedding, wte: Embedding,
blocks: Vec<Block>, blocks: Vec<Block>,
@ -324,23 +334,23 @@ pub struct Llama {
} }
impl Llama { impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?; let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?; let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() { for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?; x = block.forward(&x, index_pos, block_idx, cache)?;
} }
let x = self.ln_f.forward(&x)?; let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?; let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32) logits.to_dtype(DType::F32)
} }
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?; let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers) let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap()) .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap())
.collect(); .collect();
Ok(Self { Ok(Self {
wte, wte,

View File

@ -7,6 +7,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)? xs / (xs.neg()?.exp()? + 1.0)?
} }
#[derive(Debug, Clone)]
struct CausalSelfAttention { struct CausalSelfAttention {
q_proj: Linear, q_proj: Linear,
k_proj: Linear, k_proj: Linear,
@ -15,14 +16,13 @@ struct CausalSelfAttention {
n_head: usize, n_head: usize,
n_key_value_head: usize, n_key_value_head: usize,
head_dim: usize, head_dim: usize,
cache: Cache,
} }
impl CausalSelfAttention { impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?; let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?; let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?; let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?; let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?; let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -36,7 +36,13 @@ impl CausalSelfAttention {
Ok(rope) Ok(rope)
} }
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?; let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?; let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?; let k = self.k_proj.forward(x)?;
@ -46,16 +52,15 @@ impl CausalSelfAttention {
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?; let q = self.apply_rotary_emb(&q, index_pos, cache)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?; let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
if self.cache.use_kv_cache { if cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap(); if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?; k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?; v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
} }
cache[block_idx] = Some((k.clone(), v.clone())) cache.kvs[block_idx] = Some((k.clone(), v.clone()))
} }
let k = self.repeat_kv(k)?; let k = self.repeat_kv(k)?;
@ -66,7 +71,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?.contiguous()?; let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?; let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now. // Convert to contiguous as matmul doesn't support strided vs for now.
@ -90,7 +95,7 @@ impl CausalSelfAttention {
} }
} }
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim; let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads; let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads; let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -106,7 +111,6 @@ impl CausalSelfAttention {
n_head: cfg.n_heads, n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads, n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads, head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
}) })
} }
} }
@ -118,6 +122,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m) Ok(m)
} }
#[derive(Debug, Clone)]
struct Mlp { struct Mlp {
c_fc1: Linear, c_fc1: Linear,
c_fc2: Linear, c_fc2: Linear,
@ -148,6 +153,7 @@ impl Mlp {
} }
} }
#[derive(Debug, Clone)]
struct Block { struct Block {
rms_1: RmsNorm, rms_1: RmsNorm,
attn: CausalSelfAttention, attn: CausalSelfAttention,
@ -165,17 +171,23 @@ impl Block {
} }
} }
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let residual = x; let residual = x;
let x = self.rms_1.forward(x)?; let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
let residual = &x; let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x) Ok(x)
} }
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?; let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?; let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = let post_attention_layernorm =
@ -189,6 +201,7 @@ impl Block {
} }
} }
#[derive(Debug, Clone)]
pub struct QLlama { pub struct QLlama {
wte: Embedding, wte: Embedding,
blocks: Vec<Block>, blocks: Vec<Block>,
@ -198,23 +211,23 @@ pub struct QLlama {
} }
impl QLlama { impl QLlama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?; let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?; let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() { for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?; x = block.forward(&x, index_pos, block_idx, cache)?;
} }
let x = self.ln_f.forward(&x)?; let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?; let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32) logits.to_dtype(DType::F32)
} }
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?; let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers) let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap()) .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap())
.collect(); .collect();
Ok(Self { Ok(Self {
wte, wte,