Llama batch (#144)

* Add a batch dimension to llama.

* Bugfixes.
Laurent Mazare
2023-07-12 11:38:19 +01:00
committed by GitHub
parent bcf96e3cf3
commit b3b39cca92
3 changed files with 32 additions and 52 deletions

View File

@@ -30,7 +30,7 @@ impl Tensor {
         let mut current_dim = 0;
         for (i, indexer) in indexers.iter().enumerate() {
             x = match indexer {
-                TensorIndexer::Select(n) => x.get(*n)?,
+                TensorIndexer::Select(n) => x.narrow(i, *n, 1)?.squeeze(i)?,
                 TensorIndexer::Narrow(left_bound, right_bound) => {
                     let start = match left_bound {
                         Bound::Included(n) => *n,
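
The old Select arm called x.get(*n), which always indexes along dimension 0 and so misbehaved for any indexer past the first dimension; narrow-then-squeeze selects index n along dimension i instead. A minimal sketch of the equivalence, assuming the Tensor::narrow / Tensor::squeeze signatures used in this commit (the select helper itself is hypothetical):

    use candle::{Device, Result, Tensor};

    // Hypothetical helper: select index `n` along dimension `d`.
    // narrow(d, n, 1) keeps the length-1 slice [n, n + 1) along dim d,
    // and squeeze(d) drops the resulting singleton dimension.
    fn select(x: &Tensor, d: usize, n: usize) -> Result<Tensor> {
        x.narrow(d, n, 1)?.squeeze(d)
    }

    fn main() -> Result<()> {
        let x = Tensor::new(&[[1u32, 2, 3], [4, 5, 6]], &Device::Cpu)?;
        assert_eq!(select(&x, 0, 1)?.dims(), &[3]); // row [4, 5, 6]
        assert_eq!(select(&x, 1, 2)?.dims(), &[2]); // column [3, 6]
        Ok(())
    }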

View File

@@ -12,8 +12,6 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-// TODO: This does not use a batch dimension. If adding it back, be cautious about the
-// transposition operations.
 use anyhow::{Error as E, Result};
 use clap::Parser;
 use rand::{distributions::Distribution, SeedableRng};
@@ -200,13 +198,14 @@ fn main() -> Result<()> {
             tokens.len()
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-        let input = Tensor::new(ctxt, &device)?;
+        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let freqs_cis = if cache.use_kv_cache {
             freqs_cis.narrow(1, index_pos, ctxt.len())?
         } else {
             freqs_cis.clone()
        };
         let logits = llama.forward(&input, &freqs_cis)?;
+        let logits = logits.squeeze(0)?;
         index_pos += ctxt.len();
         let next_token = if let Some(temperature) = args.temperature {
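
The prompt now enters the model as a batch of one, and the batch dimension is stripped again before sampling. A shape-only sketch of that round trip, assuming candle's unsqueeze/squeeze semantics:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let ctxt: Vec<u32> = vec![1, 2, 3, 4];
        // (seq_len,) -> (1, seq_len): the batch-of-one fed to the model.
        let input = Tensor::new(ctxt.as_slice(), &device)?.unsqueeze(0)?;
        assert_eq!(input.dims(), &[1, 4]);
        // squeeze(0) undoes it, restoring the unbatched tensor the
        // sampling code expects from the logits.
        assert_eq!(input.squeeze(0)?.dims(), &[4]);
        Ok(())
    }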

View File

@@ -1,5 +1,5 @@
-use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Linear, VarBuilder};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -71,23 +71,9 @@ fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
     Ok(Linear::new(weight, None))
 }
 
-struct Embedding {
-    embeddings: Tensor,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor) -> Self {
-        Self { embeddings }
-    }
-    fn load(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
-        Ok(Self::new(embeddings))
-    }
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Tensor::embedding(indexes, &self.embeddings)
-    }
-}
+fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
+    Ok(Embedding::new(embeddings, cfg.hidden_size))
+}
 
 struct RmsNorm {
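
The hand-rolled Embedding struct gives way to candle_nn::Embedding, built from the same weight tensor via the new embedding helper. A minimal sketch of the lookup it performs, with stand-in sizes and zero weights in place of vb.get(..., "weight"):

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::Embedding;

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let (vocab_size, hidden_size) = (8, 4);
        // Stand-in for the weights loaded from the checkpoint.
        let weight = Tensor::zeros((vocab_size, hidden_size), DType::F32, &device)?;
        let wte = Embedding::new(weight, hidden_size);
        // (b_sz, seq_len) token ids -> (b_sz, seq_len, hidden_size) vectors.
        let ids = Tensor::new(&[[0u32, 1, 2]], &device)?;
        assert_eq!(wte.forward(&ids)?.dims(), &[1, 3, hidden_size]);
        Ok(())
    }
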
@@ -108,15 +94,15 @@ impl RmsNorm {
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
-        let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
+        let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
-            .broadcast_as((seq_len, size))?;
+            .broadcast_as((b_sz, seq_len, size))?;
         let x = (scale * x_normed)?;
         let x = x.to_dtype(in_dtype)?;
         Ok(x)
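
With the batch dimension added this is still plain RMSNorm over the hidden axis; only the broadcast shapes gain a leading b_sz. Per batch element and position, in the notation of the code above:

    y_i = s_i \cdot x_i \Big/ \sqrt{\tfrac{1}{H} \textstyle\sum_{j=1}^{H} x_j^2 + \varepsilon}

with H = hidden_size, s = scale, and \varepsilon = 1e-5 added before the square root, matching (norm_x + 1e-5)?.sqrt()? above.
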
@@ -143,8 +129,8 @@ impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let mut dims = x.dims().to_vec();
         let fcis_dims = freqs_cis.dims();
-        let freqs_cis = if dims[1] < fcis_dims[1] {
-            freqs_cis.narrow(1, 0, dims[1])?
+        let freqs_cis = if dims[2] < fcis_dims[1] {
+            freqs_cis.narrow(1, 0, dims[2])?
         } else {
             freqs_cis.clone()
         };
@@ -169,35 +155,34 @@ impl CausalSelfAttention {
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
         let x_dtype = x.dtype();
-        let (t, c) = x.shape().r2()?;
+        let (b_sz, seq_len, n_embd) = x.shape().r3()?;
         let qkv = self.c_attn.forward(x)?;
         let qkv = qkv.to_dtype(DType::F32)?;
-        let n_embd = c;
-        let q = qkv.narrow(1, 0, n_embd)?;
-        let k = qkv.narrow(1, n_embd, n_embd)?;
-        let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
-        let target_dim = [t, self.n_head, c / self.n_head];
-        let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = qkv.narrow(D::Minus1, 0, n_embd)?;
+        let k = qkv.narrow(D::Minus1, n_embd, n_embd)?;
+        let v = qkv.narrow(D::Minus1, 2 * n_embd, n_embd)?;
+        let target_dim = [b_sz, seq_len, self.n_head, n_embd / self.n_head];
+        let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let mut v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
         let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
         if self.cache.use_kv_cache {
             let mut cache = self.cache.kvs.lock().unwrap();
             if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
-                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
+                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
                 if k_seq_len > MAX_SEQ_LEN {
                     k = k
-                        .narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
                 let v_seq_len = v.dims()[1];
                 if v_seq_len > 2 * MAX_SEQ_LEN {
                     v = v
-                        .narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
             }
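
The reshape/transpose pair now maps (b_sz, seq_len, n_embd) to the per-head layout (b_sz, n_head, seq_len, head_dim), which is why every dimension index in this function shifts by one relative to the unbatched code. A shape-only sketch with made-up sizes:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let (b_sz, seq_len, n_head, head_dim) = (2, 5, 4, 8);
        // Stand-in for one of the q/k/v projections: (b, t, n_embd).
        let q = Tensor::zeros((b_sz, seq_len, n_head * head_dim), DType::F32, &device)?;
        // (b, t, n_embd) -> (b, t, n_head, head_dim) -> (b, n_head, t, head_dim)
        let q = q
            .reshape((b_sz, seq_len, n_head, head_dim))?
            .transpose(1, 2)?;
        assert_eq!(q.dims(), &[b_sz, n_head, seq_len, head_dim]);
        Ok(())
    }
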
@@ -205,12 +190,12 @@ impl CausalSelfAttention {
         }
         let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
-        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
+        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
-        let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = y.to_dtype(x_dtype)?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
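
Per batch element and head, this block is standard masked scaled dot-product attention:

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left( QK^\top / \sqrt{d_k} + M \right) V

where d_k = n_embd / n_head and M is the causal mask (f32::NEG_INFINITY at future positions); the only change is that q, k, v carry the extra leading batch dimension, so flattening the heads back to (b_sz, seq_len, n_embd) uses transpose(1, 2) instead of transpose(0, 1).
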
@@ -336,23 +321,19 @@ impl Llama {
     }
 
     pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        // TODO: Support for mini-batches? (i.e. r2)
-        let t = x.shape().r1()?;
+        let (_b_sz, seq_len) = x.shape().r2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, freqs_cis, block_idx)?;
         }
         let x = self.ln_f.forward(&x)?;
-        let x = x.narrow(0, t - 1, 1)?;
+        let x = x.i((.., seq_len - 1, ..))?;
         let logits = self.lm_head.forward(&x)?;
-        let logits = logits.to_dtype(DType::F32)?;
-        let (b, vocab_size) = logits.shape().r2()?;
-        assert_eq!(b, 1);
-        logits.reshape(vocab_size)
+        logits.to_dtype(DType::F32)
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = Embedding::load(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
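
The narrow/reshape dance at the end of forward is replaced by the IndexOp indexing that the candle-core change at the top of this commit enables: an integer index selects a position and drops that dimension, while a .. range passes the dimension through. A minimal sketch with stand-in sizes:

    use candle::{DType, Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // (b_sz, seq_len, hidden) activations out of the final norm.
        let x = Tensor::zeros((2, 5, 3), DType::F32, &device)?;
        // Keep only the last position: `4` selects (and squeezes) along
        // the sequence dimension, `..` keeps batch and hidden intact.
        let last = x.i((.., 4, ..))?;
        assert_eq!(last.dims(), &[2, 3]);
        Ok(())
    }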