mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Explicit caching in llama2.c.
This commit is contained in:
@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use model::{Config, Llama};
|
||||
use model::{Cache, Config, Llama};
|
||||
use qmodel::QLlama;
|
||||
use weights::TransformerWeights;
|
||||
|
||||
@ -160,10 +160,10 @@ enum Model {
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
|
||||
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
|
||||
match self {
|
||||
Self::Llama(l) => Ok(l.forward(xs, pos)?),
|
||||
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
|
||||
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
|
||||
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, config)?;
|
||||
|
||||
let tokens = match &args.pretokenized_dir {
|
||||
None => {
|
||||
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
for inp_tgt in batch_iter {
|
||||
let (inp, tgt) = inp_tgt?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let logits = model.forward(&inp, 0, &mut cache)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
println!("{}", loss.to_vec0::<f32>()?);
|
||||
}
|
||||
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let is_safetensors = config_path
|
||||
.extension()
|
||||
.map_or(false, |v| v == "safetensors");
|
||||
let (model, config) = if is_gguf {
|
||||
let (model, config, mut cache) = if is_gguf {
|
||||
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
||||
let (_vocab_size, dim) = vb
|
||||
.get_no_shape("model.embed_tokens.weight")?
|
||||
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
&device,
|
||||
);
|
||||
let cache = model::Cache::new(true, &config, fake_vb)?;
|
||||
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
let model = Model::QLlama(QLlama::load(vb, config.clone())?);
|
||||
(model, config, cache)
|
||||
} else if is_safetensors {
|
||||
let config = Config::tiny_15m();
|
||||
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
let model = Model::Llama(Llama::load(vb, config.clone())?);
|
||||
(model, config, cache)
|
||||
} else {
|
||||
let mut file = std::fs::File::open(config_path)?;
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
let model = Model::Llama(Llama::load(vb, config.clone())?);
|
||||
(model, config, cache)
|
||||
};
|
||||
|
||||
println!("starting the inference loop");
|
||||
@ -338,7 +338,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, index_pos)?;
|
||||
let logits = model.forward(&input, index_pos, &mut cache)?;
|
||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
||||
logits
|
||||
|
@ -8,6 +8,7 @@ fn valid_loss(
|
||||
model: &Llama,
|
||||
args: &crate::TrainingCmd,
|
||||
device: &Device,
|
||||
cache: &mut Cache,
|
||||
) -> Result<f64> {
|
||||
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
@ -15,7 +16,7 @@ fn valid_loss(
|
||||
let mut cnt = 0usize;
|
||||
for inp_tgt in batch_iter.take(50) {
|
||||
let (inp, tgt) = inp_tgt?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let logits = model.forward(&inp, 0, cache)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
sum_ce += loss.to_vec0::<f32>()? as f64;
|
||||
cnt += 1;
|
||||
@ -37,8 +38,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
||||
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
|
||||
let cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, config)?;
|
||||
let params = candle_nn::ParamsAdamW {
|
||||
lr: args.learning_rate,
|
||||
..Default::default()
|
||||
@ -46,14 +47,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
||||
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
|
||||
for (batch_index, batch) in batch_iter.enumerate() {
|
||||
let (inp, tgt) = batch?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let logits = model.forward(&inp, 0, &mut cache)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
opt.backward_step(&loss)?;
|
||||
|
||||
if batch_index > 0 && batch_index % 100 == 0 {
|
||||
// TODO: Add a way to deactivate the backprop graph tracking when computing the
|
||||
// validation loss.
|
||||
let loss = valid_loss(&dataset, &model, args, &device)?;
|
||||
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
|
||||
println!("{batch_index} {loss}");
|
||||
}
|
||||
if batch_index > 0 && batch_index % 1000 == 0 {
|
||||
|
@ -2,7 +2,6 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::linear_no_bias as linear;
|
||||
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
@ -70,12 +69,11 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
masks: HashMap<usize, Tensor>,
|
||||
pub use_kv_cache: bool,
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||
pub kvs: Vec<Option<(Tensor, Tensor)>>,
|
||||
pub cos: Tensor,
|
||||
pub sin: Tensor,
|
||||
device: Device,
|
||||
@ -105,25 +103,24 @@ impl Cache {
|
||||
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
||||
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
||||
Ok(Self {
|
||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
||||
masks: HashMap::new(),
|
||||
use_kv_cache,
|
||||
kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
|
||||
kvs: vec![None; cfg.n_layers],
|
||||
cos,
|
||||
sin,
|
||||
device: vb.device().clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn mask(&self, t: usize) -> Result<Tensor> {
|
||||
let mut masks = self.masks.lock().unwrap();
|
||||
if let Some(mask) = masks.get(&t) {
|
||||
pub fn mask(&mut self, t: usize) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
||||
masks.insert(t, mask.clone());
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
@ -133,6 +130,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CausalSelfAttention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
@ -141,14 +139,13 @@ struct CausalSelfAttention {
|
||||
n_head: usize,
|
||||
n_key_value_head: usize,
|
||||
head_dim: usize,
|
||||
cache: Cache,
|
||||
}
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
|
||||
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
|
||||
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
|
||||
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
|
||||
let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
|
||||
let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
|
||||
let cos = cos.unsqueeze(1)?;
|
||||
let sin = sin.unsqueeze(1)?;
|
||||
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
||||
@ -162,7 +159,13 @@ impl CausalSelfAttention {
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
index_pos: usize,
|
||||
block_idx: usize,
|
||||
cache: &mut Cache,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let q = self.q_proj.forward(x)?;
|
||||
let k = self.k_proj.forward(x)?;
|
||||
@ -172,16 +175,15 @@ impl CausalSelfAttention {
|
||||
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
|
||||
|
||||
if self.cache.use_kv_cache {
|
||||
let mut cache = self.cache.kvs.lock().unwrap();
|
||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
||||
if cache.use_kv_cache {
|
||||
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
|
||||
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
||||
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
|
||||
}
|
||||
cache[block_idx] = Some((k.clone(), v.clone()))
|
||||
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
|
||||
}
|
||||
|
||||
let k = self.repeat_kv(k)?;
|
||||
@ -192,7 +194,7 @@ impl CausalSelfAttention {
|
||||
let v = v.transpose(1, 2)?.contiguous()?;
|
||||
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
@ -216,7 +218,7 @@ impl CausalSelfAttention {
|
||||
}
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let size_in = cfg.dim;
|
||||
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
|
||||
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
|
||||
@ -232,7 +234,6 @@ impl CausalSelfAttention {
|
||||
n_head: cfg.n_heads,
|
||||
n_key_value_head: cfg.n_kv_heads,
|
||||
head_dim: cfg.dim / cfg.n_heads,
|
||||
cache: cache.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -244,6 +245,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
c_fc1: Linear,
|
||||
c_fc2: Linear,
|
||||
@ -274,6 +276,7 @@ impl Mlp {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Block {
|
||||
rms_1: RmsNorm,
|
||||
attn: CausalSelfAttention,
|
||||
@ -291,17 +294,23 @@ impl Block {
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
index_pos: usize,
|
||||
block_idx: usize,
|
||||
cache: &mut Cache,
|
||||
) -> Result<Tensor> {
|
||||
let residual = x;
|
||||
let x = self.rms_1.forward(x)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
|
||||
let residual = &x;
|
||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
|
||||
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
||||
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
|
||||
let post_attention_layernorm =
|
||||
@ -315,6 +324,7 @@ impl Block {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Llama {
|
||||
wte: Embedding,
|
||||
blocks: Vec<Block>,
|
||||
@ -324,23 +334,23 @@ pub struct Llama {
|
||||
}
|
||||
|
||||
impl Llama {
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
|
||||
let (_b_sz, _seq_len) = x.dims2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx)?;
|
||||
x = block.forward(&x, index_pos, block_idx, cache)?;
|
||||
}
|
||||
let x = self.ln_f.forward(&x)?;
|
||||
let logits = self.lm_head.forward(&x)?;
|
||||
logits.to_dtype(DType::F32)
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
|
||||
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
|
||||
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.n_layers)
|
||||
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
|
||||
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap())
|
||||
.collect();
|
||||
Ok(Self {
|
||||
wte,
|
||||
|
@ -7,6 +7,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CausalSelfAttention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
@ -15,14 +16,13 @@ struct CausalSelfAttention {
|
||||
n_head: usize,
|
||||
n_key_value_head: usize,
|
||||
head_dim: usize,
|
||||
cache: Cache,
|
||||
}
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
|
||||
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
|
||||
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
|
||||
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
|
||||
let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
|
||||
let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
|
||||
let cos = cos.unsqueeze(1)?;
|
||||
let sin = sin.unsqueeze(1)?;
|
||||
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
||||
@ -36,7 +36,13 @@ impl CausalSelfAttention {
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
index_pos: usize,
|
||||
block_idx: usize,
|
||||
cache: &mut Cache,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let q = self.q_proj.forward(x)?;
|
||||
let k = self.k_proj.forward(x)?;
|
||||
@ -46,16 +52,15 @@ impl CausalSelfAttention {
|
||||
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
|
||||
|
||||
if self.cache.use_kv_cache {
|
||||
let mut cache = self.cache.kvs.lock().unwrap();
|
||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
||||
if cache.use_kv_cache {
|
||||
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
|
||||
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
||||
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
|
||||
}
|
||||
cache[block_idx] = Some((k.clone(), v.clone()))
|
||||
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
|
||||
}
|
||||
|
||||
let k = self.repeat_kv(k)?;
|
||||
@ -66,7 +71,7 @@ impl CausalSelfAttention {
|
||||
let v = v.transpose(1, 2)?.contiguous()?;
|
||||
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
@ -90,7 +95,7 @@ impl CausalSelfAttention {
|
||||
}
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let size_in = cfg.dim;
|
||||
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
|
||||
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
|
||||
@ -106,7 +111,6 @@ impl CausalSelfAttention {
|
||||
n_head: cfg.n_heads,
|
||||
n_key_value_head: cfg.n_kv_heads,
|
||||
head_dim: cfg.dim / cfg.n_heads,
|
||||
cache: cache.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -118,6 +122,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
c_fc1: Linear,
|
||||
c_fc2: Linear,
|
||||
@ -148,6 +153,7 @@ impl Mlp {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Block {
|
||||
rms_1: RmsNorm,
|
||||
attn: CausalSelfAttention,
|
||||
@ -165,17 +171,23 @@ impl Block {
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
index_pos: usize,
|
||||
block_idx: usize,
|
||||
cache: &mut Cache,
|
||||
) -> Result<Tensor> {
|
||||
let residual = x;
|
||||
let x = self.rms_1.forward(x)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
|
||||
let residual = &x;
|
||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
|
||||
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
||||
let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
|
||||
let post_attention_layernorm =
|
||||
@ -189,6 +201,7 @@ impl Block {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QLlama {
|
||||
wte: Embedding,
|
||||
blocks: Vec<Block>,
|
||||
@ -198,23 +211,23 @@ pub struct QLlama {
|
||||
}
|
||||
|
||||
impl QLlama {
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
|
||||
let (_b_sz, _seq_len) = x.dims2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx)?;
|
||||
x = block.forward(&x, index_pos, block_idx, cache)?;
|
||||
}
|
||||
let x = self.ln_f.forward(&x)?;
|
||||
let logits = self.lm_head.forward(&x)?;
|
||||
logits.to_dtype(DType::F32)
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
|
||||
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
|
||||
let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.n_layers)
|
||||
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
|
||||
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap())
|
||||
.collect();
|
||||
Ok(Self {
|
||||
wte,
|
||||
|
Reference in New Issue
Block a user