mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Explicit caching in llama2.c.
This commit is contained in:
@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
|
|||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
use model::{Config, Llama};
|
use model::{Cache, Config, Llama};
|
||||||
use qmodel::QLlama;
|
use qmodel::QLlama;
|
||||||
use weights::TransformerWeights;
|
use weights::TransformerWeights;
|
||||||
|
|
||||||
@ -160,10 +160,10 @@ enum Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
|
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Self::Llama(l) => Ok(l.forward(xs, pos)?),
|
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
|
||||||
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
|
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
|||||||
let config = Config::from_reader(&mut file)?;
|
let config = Config::from_reader(&mut file)?;
|
||||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||||
let vb = weights.var_builder(&config, &device)?;
|
let vb = weights.var_builder(&config, &device)?;
|
||||||
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
|
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||||
let model = Llama::load(vb, &cache, config)?;
|
let model = Llama::load(vb, config)?;
|
||||||
|
|
||||||
let tokens = match &args.pretokenized_dir {
|
let tokens = match &args.pretokenized_dir {
|
||||||
None => {
|
None => {
|
||||||
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
|||||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||||
for inp_tgt in batch_iter {
|
for inp_tgt in batch_iter {
|
||||||
let (inp, tgt) = inp_tgt?;
|
let (inp, tgt) = inp_tgt?;
|
||||||
let logits = model.forward(&inp, 0)?;
|
let logits = model.forward(&inp, 0, &mut cache)?;
|
||||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||||
println!("{}", loss.to_vec0::<f32>()?);
|
println!("{}", loss.to_vec0::<f32>()?);
|
||||||
}
|
}
|
||||||
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
let is_safetensors = config_path
|
let is_safetensors = config_path
|
||||||
.extension()
|
.extension()
|
||||||
.map_or(false, |v| v == "safetensors");
|
.map_or(false, |v| v == "safetensors");
|
||||||
let (model, config) = if is_gguf {
|
let (model, config, mut cache) = if is_gguf {
|
||||||
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
||||||
let (_vocab_size, dim) = vb
|
let (_vocab_size, dim) = vb
|
||||||
.get_no_shape("model.embed_tokens.weight")?
|
.get_no_shape("model.embed_tokens.weight")?
|
||||||
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
&device,
|
&device,
|
||||||
);
|
);
|
||||||
let cache = model::Cache::new(true, &config, fake_vb)?;
|
let cache = model::Cache::new(true, &config, fake_vb)?;
|
||||||
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
let model = Model::QLlama(QLlama::load(vb, config.clone())?);
|
||||||
(model, config)
|
(model, config, cache)
|
||||||
} else if is_safetensors {
|
} else if is_safetensors {
|
||||||
let config = Config::tiny_15m();
|
let config = Config::tiny_15m();
|
||||||
let tensors = candle::safetensors::load(config_path, &device)?;
|
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||||
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
let model = Model::Llama(Llama::load(vb, config.clone())?);
|
||||||
(model, config)
|
(model, config, cache)
|
||||||
} else {
|
} else {
|
||||||
let mut file = std::fs::File::open(config_path)?;
|
let mut file = std::fs::File::open(config_path)?;
|
||||||
let config = Config::from_reader(&mut file)?;
|
let config = Config::from_reader(&mut file)?;
|
||||||
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||||
let vb = weights.var_builder(&config, &device)?;
|
let vb = weights.var_builder(&config, &device)?;
|
||||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
let model = Model::Llama(Llama::load(vb, config.clone())?);
|
||||||
(model, config)
|
(model, config, cache)
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
@ -338,7 +338,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, index_pos)?;
|
let logits = model.forward(&input, index_pos, &mut cache)?;
|
||||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||||
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
||||||
logits
|
logits
|
||||||
|
@ -8,6 +8,7 @@ fn valid_loss(
|
|||||||
model: &Llama,
|
model: &Llama,
|
||||||
args: &crate::TrainingCmd,
|
args: &crate::TrainingCmd,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
|
cache: &mut Cache,
|
||||||
) -> Result<f64> {
|
) -> Result<f64> {
|
||||||
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
|
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
|
||||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||||
@ -15,7 +16,7 @@ fn valid_loss(
|
|||||||
let mut cnt = 0usize;
|
let mut cnt = 0usize;
|
||||||
for inp_tgt in batch_iter.take(50) {
|
for inp_tgt in batch_iter.take(50) {
|
||||||
let (inp, tgt) = inp_tgt?;
|
let (inp, tgt) = inp_tgt?;
|
||||||
let logits = model.forward(&inp, 0)?;
|
let logits = model.forward(&inp, 0, cache)?;
|
||||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||||
sum_ce += loss.to_vec0::<f32>()? as f64;
|
sum_ce += loss.to_vec0::<f32>()? as f64;
|
||||||
cnt += 1;
|
cnt += 1;
|
||||||
@ -37,8 +38,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
|||||||
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
||||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||||
|
|
||||||
let cache = Cache::new(false, &config, vb.pp("rot"))?;
|
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||||
let model = Llama::load(vb, &cache, config)?;
|
let model = Llama::load(vb, config)?;
|
||||||
let params = candle_nn::ParamsAdamW {
|
let params = candle_nn::ParamsAdamW {
|
||||||
lr: args.learning_rate,
|
lr: args.learning_rate,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@ -46,14 +47,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
|||||||
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
|
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
|
||||||
for (batch_index, batch) in batch_iter.enumerate() {
|
for (batch_index, batch) in batch_iter.enumerate() {
|
||||||
let (inp, tgt) = batch?;
|
let (inp, tgt) = batch?;
|
||||||
let logits = model.forward(&inp, 0)?;
|
let logits = model.forward(&inp, 0, &mut cache)?;
|
||||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||||
opt.backward_step(&loss)?;
|
opt.backward_step(&loss)?;
|
||||||
|
|
||||||
if batch_index > 0 && batch_index % 100 == 0 {
|
if batch_index > 0 && batch_index % 100 == 0 {
|
||||||
// TODO: Add a way to deactivate the backprop graph tracking when computing the
|
// TODO: Add a way to deactivate the backprop graph tracking when computing the
|
||||||
// validation loss.
|
// validation loss.
|
||||||
let loss = valid_loss(&dataset, &model, args, &device)?;
|
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
|
||||||
println!("{batch_index} {loss}");
|
println!("{batch_index} {loss}");
|
||||||
}
|
}
|
||||||
if batch_index > 0 && batch_index % 1000 == 0 {
|
if batch_index > 0 && batch_index % 1000 == 0 {
|
||||||
|
@ -2,7 +2,6 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
|||||||
use candle_nn::linear_no_bias as linear;
|
use candle_nn::linear_no_bias as linear;
|
||||||
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
|
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
@ -70,12 +69,11 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Cache {
|
pub struct Cache {
|
||||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
masks: HashMap<usize, Tensor>,
|
||||||
pub use_kv_cache: bool,
|
pub use_kv_cache: bool,
|
||||||
#[allow(clippy::type_complexity)]
|
pub kvs: Vec<Option<(Tensor, Tensor)>>,
|
||||||
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
|
||||||
pub cos: Tensor,
|
pub cos: Tensor,
|
||||||
pub sin: Tensor,
|
pub sin: Tensor,
|
||||||
device: Device,
|
device: Device,
|
||||||
@ -105,25 +103,24 @@ impl Cache {
|
|||||||
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
||||||
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
masks: HashMap::new(),
|
||||||
use_kv_cache,
|
use_kv_cache,
|
||||||
kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
|
kvs: vec![None; cfg.n_layers],
|
||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
device: vb.device().clone(),
|
device: vb.device().clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn mask(&self, t: usize) -> Result<Tensor> {
|
pub fn mask(&mut self, t: usize) -> Result<Tensor> {
|
||||||
let mut masks = self.masks.lock().unwrap();
|
if let Some(mask) = self.masks.get(&t) {
|
||||||
if let Some(mask) = masks.get(&t) {
|
|
||||||
Ok(mask.clone())
|
Ok(mask.clone())
|
||||||
} else {
|
} else {
|
||||||
let mask: Vec<_> = (0..t)
|
let mask: Vec<_> = (0..t)
|
||||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||||
.collect();
|
.collect();
|
||||||
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
||||||
masks.insert(t, mask.clone());
|
self.masks.insert(t, mask.clone());
|
||||||
Ok(mask)
|
Ok(mask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -133,6 +130,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
|||||||
xs / (xs.neg()?.exp()? + 1.0)?
|
xs / (xs.neg()?.exp()? + 1.0)?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct CausalSelfAttention {
|
struct CausalSelfAttention {
|
||||||
q_proj: Linear,
|
q_proj: Linear,
|
||||||
k_proj: Linear,
|
k_proj: Linear,
|
||||||
@ -141,14 +139,13 @@ struct CausalSelfAttention {
|
|||||||
n_head: usize,
|
n_head: usize,
|
||||||
n_key_value_head: usize,
|
n_key_value_head: usize,
|
||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
cache: Cache,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
impl CausalSelfAttention {
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
|
||||||
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
|
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
|
||||||
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
|
let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
|
||||||
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
|
let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
|
||||||
let cos = cos.unsqueeze(1)?;
|
let cos = cos.unsqueeze(1)?;
|
||||||
let sin = sin.unsqueeze(1)?;
|
let sin = sin.unsqueeze(1)?;
|
||||||
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
||||||
@ -162,7 +159,13 @@ impl CausalSelfAttention {
|
|||||||
Ok(rope)
|
Ok(rope)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&self,
|
||||||
|
x: &Tensor,
|
||||||
|
index_pos: usize,
|
||||||
|
block_idx: usize,
|
||||||
|
cache: &mut Cache,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||||
let q = self.q_proj.forward(x)?;
|
let q = self.q_proj.forward(x)?;
|
||||||
let k = self.k_proj.forward(x)?;
|
let k = self.k_proj.forward(x)?;
|
||||||
@ -172,16 +175,15 @@ impl CausalSelfAttention {
|
|||||||
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||||
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||||
|
|
||||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
|
||||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
|
||||||
|
|
||||||
if self.cache.use_kv_cache {
|
if cache.use_kv_cache {
|
||||||
let mut cache = self.cache.kvs.lock().unwrap();
|
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
|
||||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
|
||||||
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
||||||
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
|
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
|
||||||
}
|
}
|
||||||
cache[block_idx] = Some((k.clone(), v.clone()))
|
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
let k = self.repeat_kv(k)?;
|
let k = self.repeat_kv(k)?;
|
||||||
@ -192,7 +194,7 @@ impl CausalSelfAttention {
|
|||||||
let v = v.transpose(1, 2)?.contiguous()?;
|
let v = v.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||||
@ -216,7 +218,7 @@ impl CausalSelfAttention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let size_in = cfg.dim;
|
let size_in = cfg.dim;
|
||||||
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
|
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
|
||||||
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
|
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
|
||||||
@ -232,7 +234,6 @@ impl CausalSelfAttention {
|
|||||||
n_head: cfg.n_heads,
|
n_head: cfg.n_heads,
|
||||||
n_key_value_head: cfg.n_kv_heads,
|
n_key_value_head: cfg.n_kv_heads,
|
||||||
head_dim: cfg.dim / cfg.n_heads,
|
head_dim: cfg.dim / cfg.n_heads,
|
||||||
cache: cache.clone(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -244,6 +245,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
Ok(m)
|
Ok(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct Mlp {
|
struct Mlp {
|
||||||
c_fc1: Linear,
|
c_fc1: Linear,
|
||||||
c_fc2: Linear,
|
c_fc2: Linear,
|
||||||
@ -274,6 +276,7 @@ impl Mlp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct Block {
|
struct Block {
|
||||||
rms_1: RmsNorm,
|
rms_1: RmsNorm,
|
||||||
attn: CausalSelfAttention,
|
attn: CausalSelfAttention,
|
||||||
@ -291,17 +294,23 @@ impl Block {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&self,
|
||||||
|
x: &Tensor,
|
||||||
|
index_pos: usize,
|
||||||
|
block_idx: usize,
|
||||||
|
cache: &mut Cache,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let residual = x;
|
let residual = x;
|
||||||
let x = self.rms_1.forward(x)?;
|
let x = self.rms_1.forward(x)?;
|
||||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
|
||||||
let residual = &x;
|
let residual = &x;
|
||||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||||
Ok(x)
|
Ok(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
|
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
|
||||||
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
||||||
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
|
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
|
||||||
let post_attention_layernorm =
|
let post_attention_layernorm =
|
||||||
@ -315,6 +324,7 @@ impl Block {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Llama {
|
pub struct Llama {
|
||||||
wte: Embedding,
|
wte: Embedding,
|
||||||
blocks: Vec<Block>,
|
blocks: Vec<Block>,
|
||||||
@ -324,23 +334,23 @@ pub struct Llama {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Llama {
|
impl Llama {
|
||||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
|
||||||
let (_b_sz, _seq_len) = x.dims2()?;
|
let (_b_sz, _seq_len) = x.dims2()?;
|
||||||
let mut x = self.wte.forward(x)?;
|
let mut x = self.wte.forward(x)?;
|
||||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
x = block.forward(&x, index_pos, block_idx)?;
|
x = block.forward(&x, index_pos, block_idx, cache)?;
|
||||||
}
|
}
|
||||||
let x = self.ln_f.forward(&x)?;
|
let x = self.ln_f.forward(&x)?;
|
||||||
let logits = self.lm_head.forward(&x)?;
|
let logits = self.lm_head.forward(&x)?;
|
||||||
logits.to_dtype(DType::F32)
|
logits.to_dtype(DType::F32)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
|
||||||
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
|
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
|
||||||
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||||
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
||||||
let blocks: Vec<_> = (0..cfg.n_layers)
|
let blocks: Vec<_> = (0..cfg.n_layers)
|
||||||
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
|
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
wte,
|
wte,
|
||||||
|
@ -7,6 +7,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
|||||||
xs / (xs.neg()?.exp()? + 1.0)?
|
xs / (xs.neg()?.exp()? + 1.0)?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct CausalSelfAttention {
|
struct CausalSelfAttention {
|
||||||
q_proj: Linear,
|
q_proj: Linear,
|
||||||
k_proj: Linear,
|
k_proj: Linear,
|
||||||
@ -15,14 +16,13 @@ struct CausalSelfAttention {
|
|||||||
n_head: usize,
|
n_head: usize,
|
||||||
n_key_value_head: usize,
|
n_key_value_head: usize,
|
||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
cache: Cache,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
impl CausalSelfAttention {
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
|
||||||
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
|
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
|
||||||
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
|
let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
|
||||||
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
|
let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
|
||||||
let cos = cos.unsqueeze(1)?;
|
let cos = cos.unsqueeze(1)?;
|
||||||
let sin = sin.unsqueeze(1)?;
|
let sin = sin.unsqueeze(1)?;
|
||||||
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
||||||
@ -36,7 +36,13 @@ impl CausalSelfAttention {
|
|||||||
Ok(rope)
|
Ok(rope)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&self,
|
||||||
|
x: &Tensor,
|
||||||
|
index_pos: usize,
|
||||||
|
block_idx: usize,
|
||||||
|
cache: &mut Cache,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||||
let q = self.q_proj.forward(x)?;
|
let q = self.q_proj.forward(x)?;
|
||||||
let k = self.k_proj.forward(x)?;
|
let k = self.k_proj.forward(x)?;
|
||||||
@ -46,16 +52,15 @@ impl CausalSelfAttention {
|
|||||||
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||||
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||||
|
|
||||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
|
||||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
|
||||||
|
|
||||||
if self.cache.use_kv_cache {
|
if cache.use_kv_cache {
|
||||||
let mut cache = self.cache.kvs.lock().unwrap();
|
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
|
||||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
|
||||||
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
||||||
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
|
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
|
||||||
}
|
}
|
||||||
cache[block_idx] = Some((k.clone(), v.clone()))
|
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
let k = self.repeat_kv(k)?;
|
let k = self.repeat_kv(k)?;
|
||||||
@ -66,7 +71,7 @@ impl CausalSelfAttention {
|
|||||||
let v = v.transpose(1, 2)?.contiguous()?;
|
let v = v.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||||
@ -90,7 +95,7 @@ impl CausalSelfAttention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let size_in = cfg.dim;
|
let size_in = cfg.dim;
|
||||||
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
|
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
|
||||||
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
|
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
|
||||||
@ -106,7 +111,6 @@ impl CausalSelfAttention {
|
|||||||
n_head: cfg.n_heads,
|
n_head: cfg.n_heads,
|
||||||
n_key_value_head: cfg.n_kv_heads,
|
n_key_value_head: cfg.n_kv_heads,
|
||||||
head_dim: cfg.dim / cfg.n_heads,
|
head_dim: cfg.dim / cfg.n_heads,
|
||||||
cache: cache.clone(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -118,6 +122,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
Ok(m)
|
Ok(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct Mlp {
|
struct Mlp {
|
||||||
c_fc1: Linear,
|
c_fc1: Linear,
|
||||||
c_fc2: Linear,
|
c_fc2: Linear,
|
||||||
@ -148,6 +153,7 @@ impl Mlp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct Block {
|
struct Block {
|
||||||
rms_1: RmsNorm,
|
rms_1: RmsNorm,
|
||||||
attn: CausalSelfAttention,
|
attn: CausalSelfAttention,
|
||||||
@ -165,17 +171,23 @@ impl Block {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&self,
|
||||||
|
x: &Tensor,
|
||||||
|
index_pos: usize,
|
||||||
|
block_idx: usize,
|
||||||
|
cache: &mut Cache,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let residual = x;
|
let residual = x;
|
||||||
let x = self.rms_1.forward(x)?;
|
let x = self.rms_1.forward(x)?;
|
||||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
|
||||||
let residual = &x;
|
let residual = &x;
|
||||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||||
Ok(x)
|
Ok(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
|
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
|
||||||
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
||||||
let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
|
let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
|
||||||
let post_attention_layernorm =
|
let post_attention_layernorm =
|
||||||
@ -189,6 +201,7 @@ impl Block {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct QLlama {
|
pub struct QLlama {
|
||||||
wte: Embedding,
|
wte: Embedding,
|
||||||
blocks: Vec<Block>,
|
blocks: Vec<Block>,
|
||||||
@ -198,23 +211,23 @@ pub struct QLlama {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl QLlama {
|
impl QLlama {
|
||||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
|
||||||
let (_b_sz, _seq_len) = x.dims2()?;
|
let (_b_sz, _seq_len) = x.dims2()?;
|
||||||
let mut x = self.wte.forward(x)?;
|
let mut x = self.wte.forward(x)?;
|
||||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
x = block.forward(&x, index_pos, block_idx)?;
|
x = block.forward(&x, index_pos, block_idx, cache)?;
|
||||||
}
|
}
|
||||||
let x = self.ln_f.forward(&x)?;
|
let x = self.ln_f.forward(&x)?;
|
||||||
let logits = self.lm_head.forward(&x)?;
|
let logits = self.lm_head.forward(&x)?;
|
||||||
logits.to_dtype(DType::F32)
|
logits.to_dtype(DType::F32)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
|
||||||
let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
|
let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
|
||||||
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||||
let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
||||||
let blocks: Vec<_> = (0..cfg.n_layers)
|
let blocks: Vec<_> = (0..cfg.n_layers)
|
||||||
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
|
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
wte,
|
wte,
|
||||||
|
Reference in New Issue
Block a user