Llama batch (#144)

* Add a batch dimension to llama.

* Bugfixes.
This commit is contained in:
Laurent Mazare
2023-07-12 11:38:19 +01:00
committed by GitHub
parent bcf96e3cf3
commit b3b39cca92
3 changed files with 32 additions and 52 deletions

View File

@ -30,7 +30,7 @@ impl Tensor {
let mut current_dim = 0;
for (i, indexer) in indexers.iter().enumerate() {
x = match indexer {
TensorIndexer::Select(n) => x.get(*n)?,
TensorIndexer::Select(n) => x.narrow(i, *n, 1)?.squeeze(i)?,
TensorIndexer::Narrow(left_bound, right_bound) => {
let start = match left_bound {
Bound::Included(n) => *n,

View File

@ -12,8 +12,6 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
// TODO: This does not use a batch dimension. If adding it back, be cautious about the
// transposition operations.
use anyhow::{Error as E, Result};
use clap::Parser;
use rand::{distributions::Distribution, SeedableRng};
@ -200,13 +198,14 @@ fn main() -> Result<()> {
tokens.len()
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?;
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let freqs_cis = if cache.use_kv_cache {
freqs_cis.narrow(1, index_pos, ctxt.len())?
} else {
freqs_cis.clone()
};
let logits = llama.forward(&input, &freqs_cis)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();
let next_token = if let Some(temperature) = args.temperature {

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, VarBuilder};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -71,23 +71,9 @@ fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
Ok(Linear::new(weight, None))
}
struct Embedding {
embeddings: Tensor,
}
impl Embedding {
fn new(embeddings: Tensor) -> Self {
Self { embeddings }
}
fn load(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Self::new(embeddings))
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
Tensor::embedding(indexes, &self.embeddings)
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size))
}
struct RmsNorm {
@ -108,15 +94,15 @@ impl RmsNorm {
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
let (seq_len, hidden_size) = x.shape().r2()?;
let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().r1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((seq_len, size))?;
.broadcast_as((b_sz, seq_len, size))?;
let x = (scale * x_normed)?;
let x = x.to_dtype(in_dtype)?;
Ok(x)
@ -143,8 +129,8 @@ impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let mut dims = x.dims().to_vec();
let fcis_dims = freqs_cis.dims();
let freqs_cis = if dims[1] < fcis_dims[1] {
freqs_cis.narrow(1, 0, dims[1])?
let freqs_cis = if dims[2] < fcis_dims[1] {
freqs_cis.narrow(1, 0, dims[2])?
} else {
freqs_cis.clone()
};
@ -169,35 +155,34 @@ impl CausalSelfAttention {
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
let x_dtype = x.dtype();
let (t, c) = x.shape().r2()?;
let (b_sz, seq_len, n_embd) = x.shape().r3()?;
let qkv = self.c_attn.forward(x)?;
let qkv = qkv.to_dtype(DType::F32)?;
let n_embd = c;
let q = qkv.narrow(1, 0, n_embd)?;
let k = qkv.narrow(1, n_embd, n_embd)?;
let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
let target_dim = [t, self.n_head, c / self.n_head];
let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
let q = qkv.narrow(D::Minus1, 0, n_embd)?;
let k = qkv.narrow(D::Minus1, n_embd, n_embd)?;
let v = qkv.narrow(D::Minus1, 2 * n_embd, n_embd)?;
let target_dim = [b_sz, seq_len, self.n_head, n_embd / self.n_head];
let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
let mut v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, freqs_cis)?;
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
let k_seq_len = k.dims()[1];
if k_seq_len > MAX_SEQ_LEN {
k = k
.narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.contiguous()?
}
let v_seq_len = v.dims()[1];
if v_seq_len > 2 * MAX_SEQ_LEN {
v = v
.narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.contiguous()?
}
}
@ -205,12 +190,12 @@ impl CausalSelfAttention {
}
let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(0, 1)?.reshape(&[t, c])?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = y.to_dtype(x_dtype)?;
let y = self.c_proj.forward(&y)?;
Ok(y)
@ -336,23 +321,19 @@ impl Llama {
}
pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
// TODO: Support for mini-batches? (i.e. r2)
let t = x.shape().r1()?;
let (_b_sz, seq_len) = x.shape().r2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, freqs_cis, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.narrow(0, t - 1, 1)?;
let x = x.i((.., seq_len - 1, ..))?;
let logits = self.lm_head.forward(&x)?;
let logits = logits.to_dtype(DType::F32)?;
let (b, vocab_size) = logits.shape().r2()?;
assert_eq!(b, 1);
logits.reshape(vocab_size)
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let wte = Embedding::load(cfg, vb.pp("model.embed_tokens"))?;
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layer)