Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16

* Start adding llama2.c.
* Model loading.
* Add the llama-v2 model.
* Start converting the weights.
* Rotary embedding tweaks.
* Get the model to generate some tokens.
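// Inference-only Llama model in the llama2.c style, implemented with candle.
//
// A minimal usage sketch (hypothetical names: `config` is assumed to match the
// checkpoint, `vb_rope` and `vb_model` are VarBuilders exposing the converted
// llama2.c tensors, `tokens` holds the prompt token ids as u32, and `device`
// is the target Device):
//
//     let cache = Cache::new(true, &config, vb_rope)?;
//     let model = Llama::load(vb_model, &cache, &config)?;
//     let input = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
//     // Logits for the last prompt position, shape (batch, vocab_size).
//     let logits = model.forward(&input, 0)?;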
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Debug, Clone)]
pub struct Config {
    pub dim: usize,        // transformer dimension
    pub hidden_dim: usize, // for ffn layers
    pub n_layers: usize,   // number of layers
    pub n_heads: usize,    // number of query heads
    pub n_kv_heads: usize, // number of key/value heads (can be < query heads because of multiquery)
    pub vocab_size: usize, // vocabulary size, usually 256 (byte-level)
    pub seq_len: usize,    // max sequence length
    pub norm_eps: f64,
}

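// Shared model state: memoized causal masks keyed by sequence length, an optional
// per-layer key/value cache, and the rotary-embedding cos/sin tables read from the
// `freq_cis_real` / `freq_cis_imag` tensors of the checkpoint. Note that
// `cfg.head_size()` used below is assumed to be provided by an `impl Config`
// elsewhere in the example (presumably `dim / n_heads`).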
#[derive(Clone)]
pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
    pub use_kv_cache: bool,
    #[allow(clippy::type_complexity)]
    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    cos: Tensor,
    sin: Tensor,
    device: Device,
}

impl Cache {
    pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?;
        let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?;
        Ok(Self {
            masks: Arc::new(Mutex::new(HashMap::new())),
            use_kv_cache,
            kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
            cos: freq_cis_real,
            sin: freq_cis_imag,
            device: vb.device().clone(),
        })
    }

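    // Returns a `(t, t)` mask with 1 above the diagonal (future positions) and 0
    // elsewhere, building and memoizing it on first use for each sequence length.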
    fn mask(&self, t: usize) -> Result<Tensor> {
        let mut masks = self.masks.lock().unwrap();
        if let Some(mask) = masks.get(&t) {
            Ok(mask.clone())
        } else {
            // TODO: If we support bool or u8 tensors, this would be better.
            let mask: Vec<_> = (0..t)
                .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
                .collect();
            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
            masks.insert(t, mask.clone());
            Ok(mask)
        }
    }
}

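// SiLU activation: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).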
fn silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

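// Bias-free linear layer; the weight tensor is stored as (out_features, in_features),
// i.e. (size2, size1).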
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    Ok(Linear::new(weight, None))
}

fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
    let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
    Ok(Embedding::new(embeddings, cfg.dim))
}

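// RMS normalization: x is scaled by 1 / sqrt(mean(x^2) + eps) over the last
// dimension, then multiplied by a learned per-channel weight.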
struct RmsNorm {
    scale: Tensor,
    eps: f64,
}

impl RmsNorm {
    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let scale = vb.get(size, "weight")?;
        Ok(Self { scale, eps })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (b_sz, seq_len, hidden_size) = x.dims3()?;
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
        let size = self.scale.dims1()?;
        let scale = self
            .scale
            .to_dtype(DType::F32)?
            .broadcast_as((b_sz, seq_len, size))?;
        let x = (scale * x_normed)?;
        Ok(x)
    }
}

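// Multi-head causal self-attention with rotary position embeddings, an optional
// key/value cache, and grouped/multi-query support (`n_kv_heads` may be smaller
// than `n_heads`).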
struct CausalSelfAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    n_head: usize,
    n_key_value_head: usize,
    head_dim: usize,
    cache: Cache,
    max_seq_len: usize,
}

impl CausalSelfAttention {
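    // Applies rotary position embeddings: the head dimension is split into two
    // halves (x0, x1) that are rotated by the cos/sin values cached for positions
    // index_pos..index_pos + seq_len.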
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2))?;
        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2))?;
        let x0 = x.narrow(D::Minus1, 0, n_embd / 2)?;
        let x1 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
        let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
        let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
        let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?;
        Ok(rope)
    }

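    // Projects x to q/k/v, applies RoPE, optionally prepends the cached keys and
    // values for this block, repeats the k/v heads to match the query heads, and
    // computes masked scaled dot-product attention followed by the output projection.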
    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
        let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
        let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;

        let q = self.apply_rotary_emb(&q, index_pos)?;
        let mut k = self.apply_rotary_emb(&k, index_pos)?;

        if self.cache.use_kv_cache {
            let mut cache = self.cache.kvs.lock().unwrap();
            if let Some((cache_k, cache_v)) = &cache[block_idx] {
                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
            }
            cache[block_idx] = Some((k.clone(), v.clone()))
        }

        let k = self.repeat_kv(k)?;
        let v = self.repeat_kv(v)?;

        let q = q.transpose(1, 2)?.contiguous()?;
        let k = k.transpose(1, 2)?.contiguous()?;
        let v = v.transpose(1, 2)?.contiguous()?;

        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
        let att = att.softmax(D::Minus1)?;
        // Convert to contiguous as matmul doesn't support strided v tensors for now.
        let y = att.matmul(&v.contiguous()?)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.o_proj.forward(&y)?;
        Ok(y)
    }

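    // Repeats each key/value head `n_head / n_key_value_head` times so that k and
    // v have as many heads as q (grouped/multi-query attention).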
    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.n_head / self.n_key_value_head;
        if n_rep == 1 {
            Ok(x)
        } else {
            let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
            let x = x
                .unsqueeze(3)?
                .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
                .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
            Ok(x)
        }
    }

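    // Loads the q/k/v/o projections; with multiquery attention the k/v projections
    // are smaller: `size_kv = head_dim * n_kv_heads`.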
    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let size_in = cfg.dim;
        let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
        let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            n_head: cfg.n_heads,
            n_key_value_head: cfg.n_kv_heads,
            head_dim: cfg.dim / cfg.n_heads,
            cache: cache.clone(),
            max_seq_len: cfg.seq_len,
        })
    }
}

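// Returns a tensor equal to `on_false` except where `mask` is non-zero, where the
// scalar `on_true` is used instead (here: setting masked attention logits to -inf).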
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

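// SwiGLU feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x)).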
struct Mlp {
    c_fc1: Linear,
    c_fc2: Linear,
    c_proj: Linear,
}

impl Mlp {
    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
        Self {
            c_fc1,
            c_fc2,
            c_proj,
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&x)
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let h_size = cfg.dim;
        let i_size = cfg.hidden_dim;
        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
        Ok(Self::new(c_fc1, c_fc2, c_proj))
    }
}

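// A pre-norm transformer block: x + attn(rms_1(x)), followed by x + mlp(rms_2(x)).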
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
        Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        }
    }

    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
        let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm =
            RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
        Ok(Self::new(
            input_layernorm,
            attn,
            post_attention_layernorm,
            mlp,
        ))
    }
}

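// The full model: token embeddings, a stack of transformer blocks, a final RmsNorm
// and the language-model head.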
pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Linear,
}

impl Llama {
    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
        Self {
            wte,
            blocks,
            ln_f,
            lm_head,
        }
    }

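    // Runs the model on token ids of shape (batch, seq_len) starting at absolute
    // position `index_pos`, and returns f32 logits for the last position only,
    // with shape (batch, vocab_size).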
    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, seq_len) = x.dims2()?;
        let mut x = self.wte.forward(x)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx)?;
        }
        let x = self.ln_f.forward(&x)?;
        let x = x.i((.., seq_len - 1, ..))?;
        let logits = self.lm_head.forward(&x)?;
        logits.to_dtype(DType::F32)
    }

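    // Loads the weights under HuggingFace-style names (`model.embed_tokens`,
    // `model.layers.{i}`, `model.norm`, `lm_head`), so the VarBuilder must expose
    // the converted llama2.c tensors under these paths.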
    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
        let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.n_layers)
            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
            .collect();
        Ok(Self::new(wte, blocks, norm, lm_head))
    }
}