Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
@@ -1,5 +1,5 @@
-use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Linear, VarBuilder};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -71,23 +71,9 @@ fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
     Ok(Linear::new(weight, None))
 }
 
-struct Embedding {
-    embeddings: Tensor,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor) -> Self {
-        Self { embeddings }
-    }
-
-    fn load(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
-        Ok(Self::new(embeddings))
-    }
-
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Tensor::embedding(indexes, &self.embeddings)
-    }
+fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
+    Ok(Embedding::new(embeddings, cfg.hidden_size))
 }
 
 struct RmsNorm {
@@ -108,15 +94,15 @@ impl RmsNorm {
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
-        let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
+        let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
-            .broadcast_as((seq_len, size))?;
+            .broadcast_as((b_sz, seq_len, size))?;
         let x = (scale * x_normed)?;
         let x = x.to_dtype(in_dtype)?;
         Ok(x)
@@ -143,8 +129,8 @@ impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let mut dims = x.dims().to_vec();
         let fcis_dims = freqs_cis.dims();
-        let freqs_cis = if dims[1] < fcis_dims[1] {
-            freqs_cis.narrow(1, 0, dims[1])?
+        let freqs_cis = if dims[2] < fcis_dims[1] {
+            freqs_cis.narrow(1, 0, dims[2])?
         } else {
             freqs_cis.clone()
         };
@@ -169,35 +155,34 @@ impl CausalSelfAttention {
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
         let x_dtype = x.dtype();
-        let (t, c) = x.shape().r2()?;
+        let (b_sz, seq_len, n_embd) = x.shape().r3()?;
         let qkv = self.c_attn.forward(x)?;
         let qkv = qkv.to_dtype(DType::F32)?;
-        let n_embd = c;
-        let q = qkv.narrow(1, 0, n_embd)?;
-        let k = qkv.narrow(1, n_embd, n_embd)?;
-        let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
-        let target_dim = [t, self.n_head, c / self.n_head];
-        let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = qkv.narrow(D::Minus1, 0, n_embd)?;
+        let k = qkv.narrow(D::Minus1, n_embd, n_embd)?;
+        let v = qkv.narrow(D::Minus1, 2 * n_embd, n_embd)?;
+        let target_dim = [b_sz, seq_len, self.n_head, n_embd / self.n_head];
+        let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let mut v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
         let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
 
         if self.cache.use_kv_cache {
             let mut cache = self.cache.kvs.lock().unwrap();
             if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
-                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
+                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
                 if k_seq_len > MAX_SEQ_LEN {
                     k = k
-                        .narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
                 let v_seq_len = v.dims()[1];
                 if v_seq_len > 2 * MAX_SEQ_LEN {
                     v = v
-                        .narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
             }
@@ -205,12 +190,12 @@ impl CausalSelfAttention {
         }
 
         let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
-        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
+        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
-        let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = y.to_dtype(x_dtype)?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
@@ -336,23 +321,19 @@ impl Llama {
     }
 
     pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        // TODO: Support for mini-batches? (i.e. r2)
-        let t = x.shape().r1()?;
+        let (_b_sz, seq_len) = x.shape().r2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, freqs_cis, block_idx)?;
         }
         let x = self.ln_f.forward(&x)?;
-        let x = x.narrow(0, t - 1, 1)?;
+        let x = x.i((.., seq_len - 1, ..))?;
         let logits = self.lm_head.forward(&x)?;
-        let logits = logits.to_dtype(DType::F32)?;
-        let (b, vocab_size) = logits.shape().r2()?;
-        assert_eq!(b, 1);
-        logits.reshape(vocab_size)
+        logits.to_dtype(DType::F32)
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = Embedding::load(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
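
Appended note (not part of the commit): a minimal sketch of what the reworked RmsNorm forward computes once the batch dimension is threaded through. It assumes the candle API exactly as it appears in the diff above (shape().r3(), sum(&[2]) keeping the reduced dimension as size 1, broadcast_as); later candle releases renamed several of these ops, so treat the snippet as illustrative rather than authoritative.

use candle::{DType, Device, Result, Tensor};

// Illustrative only: batched RMS norm mirroring the updated RmsNorm::forward.
// `scale` stands in for self.scale, a 1-D tensor of length hidden_size.
fn rms_norm_3d(x: &Tensor, scale: &Tensor) -> Result<Tensor> {
    let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
    let x = x.to_dtype(DType::F32)?;
    // Mean of squares over the hidden dimension (dim 2).
    let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
    let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
    let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
    // Broadcast the per-channel scale over the batch and sequence dimensions.
    let scale = scale
        .to_dtype(DType::F32)?
        .broadcast_as((b_sz, seq_len, hidden_size))?;
    scale * x_normed
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::ones((2, 4, 8), DType::F32, &dev)?; // (b_sz, seq_len, hidden_size)
    let scale = Tensor::ones(8, DType::F32, &dev)?;
    let y = rms_norm_3d(&x, &scale)?;
    println!("{:?}", y.shape()); // expected: (2, 4, 8)
    Ok(())
}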