From b3b39cca92b084e473fc14b23575f8ecf201ebbf Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 12 Jul 2023 11:38:19 +0100
Subject: [PATCH] Llama batch (#144)

* Add a batch dimension to llama.

* Bugfixes.
---
 candle-core/src/indexer.rs              |  2 +-
 candle-examples/examples/llama/main.rs  |  5 +-
 candle-examples/examples/llama/model.rs | 77 ++++++++++---------------
 3 files changed, 32 insertions(+), 52 deletions(-)

diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index 725ba732..0651b791 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -30,7 +30,7 @@ impl Tensor {
         let mut current_dim = 0;
         for (i, indexer) in indexers.iter().enumerate() {
             x = match indexer {
-                TensorIndexer::Select(n) => x.get(*n)?,
+                TensorIndexer::Select(n) => x.narrow(i, *n, 1)?.squeeze(i)?,
                 TensorIndexer::Narrow(left_bound, right_bound) => {
                     let start = match left_bound {
                         Bound::Included(n) => *n,
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index aeee6867..ac13dfee 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -12,8 +12,6 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-// TODO: This does not use a batch dimension. If adding it back, be cautious about the
-// transposition operations.
 use anyhow::{Error as E, Result};
 use clap::Parser;
 use rand::{distributions::Distribution, SeedableRng};
@@ -200,13 +198,14 @@ fn main() -> Result<()> {
             tokens.len()
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-        let input = Tensor::new(ctxt, &device)?;
+        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let freqs_cis = if cache.use_kv_cache {
             freqs_cis.narrow(1, index_pos, ctxt.len())?
         } else {
             freqs_cis.clone()
         };
         let logits = llama.forward(&input, &freqs_cis)?;
+        let logits = logits.squeeze(0)?;
         index_pos += ctxt.len();
 
         let next_token = if let Some(temperature) = args.temperature {
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 8ff15564..daab199d 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -1,5 +1,5 @@
-use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Linear, VarBuilder};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -71,23 +71,9 @@ fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
     Ok(Linear::new(weight, None))
 }
 
-struct Embedding {
-    embeddings: Tensor,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor) -> Self {
-        Self { embeddings }
-    }
-
-    fn load(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
-        Ok(Self::new(embeddings))
-    }
-
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Tensor::embedding(indexes, &self.embeddings)
-    }
+fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
+    Ok(Embedding::new(embeddings, cfg.hidden_size))
 }
 
 struct RmsNorm {
@@ -108,15 +94,15 @@ impl RmsNorm {
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
-        let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
+        let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
-            .broadcast_as((seq_len, size))?;
+            .broadcast_as((b_sz, seq_len, size))?;
         let x = (scale * x_normed)?;
         let x = x.to_dtype(in_dtype)?;
         Ok(x)
@@ -143,8 +129,8 @@ impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let mut dims = x.dims().to_vec();
         let fcis_dims = freqs_cis.dims();
-        let freqs_cis = if dims[1] < fcis_dims[1] {
-            freqs_cis.narrow(1, 0, dims[1])?
+        let freqs_cis = if dims[2] < fcis_dims[1] {
+            freqs_cis.narrow(1, 0, dims[2])?
         } else {
             freqs_cis.clone()
         };
@@ -169,35 +155,34 @@ impl CausalSelfAttention {
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
         let x_dtype = x.dtype();
-        let (t, c) = x.shape().r2()?;
+        let (b_sz, seq_len, n_embd) = x.shape().r3()?;
         let qkv = self.c_attn.forward(x)?;
         let qkv = qkv.to_dtype(DType::F32)?;
-        let n_embd = c;
-        let q = qkv.narrow(1, 0, n_embd)?;
-        let k = qkv.narrow(1, n_embd, n_embd)?;
-        let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
-        let target_dim = [t, self.n_head, c / self.n_head];
-        let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = qkv.narrow(D::Minus1, 0, n_embd)?;
+        let k = qkv.narrow(D::Minus1, n_embd, n_embd)?;
+        let v = qkv.narrow(D::Minus1, 2 * n_embd, n_embd)?;
+        let target_dim = [b_sz, seq_len, self.n_head, n_embd / self.n_head];
+        let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let mut v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
         let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
 
         if self.cache.use_kv_cache {
             let mut cache = self.cache.kvs.lock().unwrap();
             if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
-                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
+                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
                 if k_seq_len > MAX_SEQ_LEN {
                     k = k
-                        .narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                        .contiguous()?
                 }
                 let v_seq_len = v.dims()[1];
                 if v_seq_len > 2 * MAX_SEQ_LEN {
                     v = v
-                        .narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                        .contiguous()?
                }
            }
@@ -205,12 +190,12 @@ impl CausalSelfAttention {
         }
 
         let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
-        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
+        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
-        let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = y.to_dtype(x_dtype)?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
@@ -336,23 +321,19 @@ impl Llama {
     }
 
     pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        // TODO: Support for mini-batches? (i.e. r2)
-        let t = x.shape().r1()?;
+        let (_b_sz, seq_len) = x.shape().r2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, freqs_cis, block_idx)?;
         }
         let x = self.ln_f.forward(&x)?;
-        let x = x.narrow(0, t - 1, 1)?;
+        let x = x.i((.., seq_len - 1, ..))?;
         let logits = self.lm_head.forward(&x)?;
-        let logits = logits.to_dtype(DType::F32)?;
-        let (b, vocab_size) = logits.shape().r2()?;
-        assert_eq!(b, 1);
-        logits.reshape(vocab_size)
+        logits.to_dtype(DType::F32)
    }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = Embedding::load(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
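
Usage note (not part of the patch): after this change the example drives the model with an explicit leading batch dimension. Below is a minimal sketch of the new calling convention, assuming a `Llama` built via `Llama::load` and the `freqs_cis` tensor from the example's main.rs; `next_token_logits` is a hypothetical helper introduced only for illustration.

// `Llama` is the model type defined in the example's model.rs.
use candle::{Device, Result, Tensor};

// Hypothetical helper showing the shape round-trip introduced by this patch.
fn next_token_logits(
    llama: &Llama,
    tokens: &[u32],
    freqs_cis: &Tensor,
    device: &Device,
) -> Result<Tensor> {
    // [seq_len] token ids -> [1, seq_len]: the model now expects a leading batch dim.
    let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
    // The model keeps only the last position, so the logits come back with the
    // batch dimension still attached, i.e. [1, vocab_size] here.
    let logits = llama.forward(&input, freqs_cis)?;
    // Drop the batch dimension again so the sampling code can stay unchanged.
    logits.squeeze(0)
}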