Refactor the llama example to make it more in sync with the other ones. (#139)

* Refactor the llama example to make it more in sync with the other ones.

* Make clippy happy.

* Properly load the safetensor weights.

* Get llama back to a working state for the safetensors case.
This commit is contained in:
Laurent Mazare
2023-07-11 17:20:55 +01:00
committed by GitHub
parent 64264d97c1
commit 760f1d7055
4 changed files with 383 additions and 613 deletions

View File

@ -20,11 +20,10 @@ use rand::{distributions::Distribution, SeedableRng};
use candle::{DType, Device, Tensor, D};
use candle_hub::{api::sync::Api, Repo, RepoType};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use candle_nn::VarBuilder;
mod var_store;
mod weights;
mod model;
use model::{Config, Llama};
const MAX_SEQ_LEN: usize = 4096;
#[cfg(feature = "mkl")]
@ -83,337 +82,6 @@ Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
#[allow(dead_code)]
struct Config {
block_size: usize,
vocab_size: usize,
n_layer: usize,
n_head: usize,
n_embd: usize,
}
#[allow(dead_code)]
impl Config {
fn config_7b() -> Self {
Self {
block_size: 4096,
vocab_size: 32000,
n_layer: 32,
n_head: 32,
n_embd: 4096,
}
}
fn config_13b() -> Self {
Self {
block_size: 4096,
vocab_size: 32000,
n_layer: 40,
n_head: 40,
n_embd: 5120,
}
}
fn config_30b() -> Self {
Self {
block_size: 4096,
vocab_size: 32000,
n_layer: 60,
n_head: 52,
n_embd: 6656,
}
}
fn config_65b() -> Self {
Self {
block_size: 4096,
vocab_size: 32000,
n_layer: 80,
n_head: 64,
n_embd: 8192,
}
}
}
struct Embedding {
embeddings: Tensor,
}
impl Embedding {
fn new(embeddings: Tensor) -> Self {
Self { embeddings }
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
let embeddings = self.embeddings.to_dtype(DTYPE)?;
Ok(Tensor::embedding(indexes, &embeddings)?)
}
}
struct Linear {
weight: Tensor,
}
impl Linear {
fn new(weight: Tensor) -> Self {
Self { weight }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let weight = self.weight.to_dtype(DTYPE)?;
let x = x.matmul(&weight.t()?)?;
Ok(x)
}
}
struct RmsNorm {
scale: Tensor,
}
impl RmsNorm {
fn new(scale: Tensor) -> Self {
Self { scale }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
let (seq_len, hidden_size) = x.shape().r2()?;
let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().r1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((seq_len, size))?;
let x = (scale * x_normed)?;
let x = x.to_dtype(DTYPE)?;
Ok(x)
}
}
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
}
fn silu(xs: &Tensor) -> Result<Tensor> {
Ok((xs / (xs.neg()?.exp()? + 1.0)?)?)
}
impl Mlp {
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
Self {
c_fc1,
c_fc2,
c_proj,
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
#[derive(Clone)]
struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
use_kv_cache: bool,
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
device: Device,
}
impl Cache {
fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
Self {
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
device: device.clone(),
}
}
fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())
} else {
// TODO: If we support bool or u8 tensors, this would be better.
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());
Ok(mask)
}
}
}
struct CausalSelfAttention {
c_attn: Linear,
c_proj: Linear,
n_head: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
Self {
c_attn,
c_proj,
n_head,
cache: cache.clone(),
}
}
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let mut dims = x.dims().to_vec();
let fcis_dims = freqs_cis.dims();
let freqs_cis = if dims[1] < fcis_dims[1] {
freqs_cis.narrow(1, 0, dims[1])?
} else {
freqs_cis.clone()
};
let v = dims.pop().unwrap();
dims.push(v / 2);
dims.push(2);
let x = x.reshape(dims)?;
let re_x = x.narrow(D::Minus1, 0, 1)?;
let im_x = x.narrow(D::Minus1, 1, 1)?;
let re_f = freqs_cis
.narrow(D::Minus1, 0, 1)?
.broadcast_as(re_x.shape())?;
let im_f = freqs_cis
.narrow(D::Minus1, 1, 1)?
.broadcast_as(im_x.shape())?;
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
let rope = Tensor::cat(&[&re, &im], D::Minus1)?;
let rope = rope.flatten_from(D::Minus2)?;
Ok(rope)
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
let (t, c) = x.shape().r2()?;
let qkv = self.c_attn.forward(x)?;
let qkv = qkv.to_dtype(DType::F32)?;
let n_embd = c;
let q = qkv.narrow(1, 0, n_embd)?;
let k = qkv.narrow(1, n_embd, n_embd)?;
let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
let target_dim = [t, self.n_head, c / self.n_head];
let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
let q = self.apply_rotary_emb(&q, freqs_cis)?;
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
let k_seq_len = k.dims()[1];
if k_seq_len > MAX_SEQ_LEN {
k = k
.narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.contiguous()?
}
let v_seq_len = v.dims()[1];
if v_seq_len > 2 * MAX_SEQ_LEN {
v = v
.narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.contiguous()?
}
}
cache[block_idx] = Some((k.clone(), v.clone()))
}
let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(0, 1)?.reshape(&[t, c])?;
let y = y.to_dtype(DTYPE)?;
let y = self.c_proj.forward(&y)?;
Ok(y)
}
}
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
rms_2,
mlp,
}
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
let x = (self
.attn
.forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
+ x)?;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
Ok(x)
}
}
struct Llama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
}
impl Llama {
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
Self {
wte,
blocks,
ln_f,
lm_head,
}
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
// TODO: Support for mini-batches? (i.e. r2)
let t = x.shape().r1()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, freqs_cis, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.narrow(0, t - 1, 1)?;
let logits = self.lm_head.forward(&x)?;
let logits = logits.to_dtype(DType::F32)?;
let (b, vocab_size) = logits.shape().r2()?;
assert_eq!(b, 1);
Ok(logits.reshape(vocab_size)?)
}
}
fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
@ -474,19 +142,15 @@ fn main() -> Result<()> {
Device::new_cuda(0)?
};
let config = Config::config_7b();
let cache = Cache::new(!args.no_kv_cache, &config, &device);
let start = std::time::Instant::now();
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
let (llama, tokenizer_filename) = match args.npy {
Some(npy) => {
println!("building the model (NPY)");
let weights = Llama::load_npy(&device, &npy, &cache, &config)?;
let token_path = std::path::Path::new("llama-tokenizer.json").to_path_buf();
(weights, token_path)
Some(_) => {
todo!("fix numpy handling if we continue supporting it")
}
None => {
let api = Api::new()?;
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
println!("building the model");
println!("loading the model weights");
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
@ -497,14 +161,20 @@ fn main() -> Result<()> {
filenames.push(filename);
}
println!("building the model (SF)");
(
Llama::load(&device, &filenames, &cache, &config)?,
tokenizer_filename,
)
println!("building the model");
let handles = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(tensors, DTYPE, &device);
(Llama::load(vb, &cache, &config)?, tokenizer_filename)
}
};
println!("Loaded in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer