Start sketching the llama example.

This commit is contained in:
laurent
2023-06-25 13:51:20 +01:00
parent a9c113248a
commit 90c140ff4b
4 changed files with 599 additions and 0 deletions

3
.gitignore vendored
View File

@ -12,3 +12,6 @@ Cargo.lock
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
*tokenizer.json
*.npz

View File

@ -0,0 +1,68 @@
# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
import sys
import torch
import numpy as np
from typing import Dict
from pathlib import Path
def tr(v):
return np.ascontiguousarray(np.transpose(v))
def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> Dict[str, torch.Tensor]:
print("start conv")
def get_and_remove(key, transpose=False):
v = state_dict[key].to(dtype).numpy()
if transpose:
v = tr(v)
del state_dict[key]
return v
converted = {}
converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
converted["lm_head.weight"] = get_and_remove("output.weight", transpose=True)
converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")
for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
print(layer_idx)
# attention
# the wq, wk, wv from the FB model are stacked in our model as c_attn
converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = tr(np.concatenate(
(
get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
)
))
converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = tr(get_and_remove(
f"layers.{layer_idx}.attention.wo.weight"
))
# mlp
converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
f"layers.{layer_idx}.feed_forward.w1.weight", transpose=True,
)
converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
f"layers.{layer_idx}.feed_forward.w2.weight", transpose=True,
)
converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
f"layers.{layer_idx}.feed_forward.w3.weight", transpose=True,
)
# rms norm
converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
return converted
def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float16") -> None:
dt = getattr(torch, dtype, None)
if not isinstance(dt, torch.dtype):
raise ValueError(f"{dtype} is not a valid dtype.")
checkpoint = torch.load(llama_ckpt, map_location="cpu")
converted = convert_state_dict(checkpoint, dtype=dt)
del checkpoint
np.savez(output_npz, **converted)
if __name__ == "__main__":
if len(sys.argv) != 2:
raise ValueError(f"usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
convert_weights(sys.argv[1])

450
examples/llama/main.rs Normal file
View File

@ -0,0 +1,450 @@
// An implementation of LLaMA https://github.com/facebookresearch/llama
//
// This is based on nanoGPT in a similar way to:
// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
//
// The tokenizer config can be retrieved from:
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
//
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{Device, Tensor};
mod var_store;
use var_store::VarBuilder;
const CONTEXT_SIZE: usize = 512;
const START_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no
From Clifford's and Northumberland's pursuit:
Had he been ta'en, we should have heard the news;
Had he been slain, we should have heard the news;
Or had he 'scaped, methinks we should have heard
The happy tidings of his good escape.
How fares my brother? why is he so sad?
RICHARD:
I cannot joy, until I be resolved
Where our right valiant father is become.
I saw him in the battle range about;
And watch'd him how he singled Clifford forth.
Methought he bore him in the thickest troop
As doth a lion in a herd of neat;
Or as a bear, encompass'd round with dogs,
Who having pinch'd a few and made them cry,
The rest stand all aloof, and bark at him.
So fared our father with his enemies;
So fled his enemies my warlike father:
Methinks, 'tis prize enough to be his son.
See how the morning opes her golden gates,
And takes her farewell of the glorious sun!
How well resembles it the prime of youth,
Trimm'd like a younker prancing to his love!
EDWARD:
Dazzle mine eyes, or do I see three suns?
RICHARD:
Three glorious suns, each one a perfect sun;
Not separated with the racking clouds,
But sever'd in a pale clear-shining sky.
See, see! they join, embrace, and seem to kiss,
As if they vow'd some league inviolable:
Now are they but one lamp, one light, one sun.
In this the heaven figures some event.
EDWARD:
'Tis wondrous strange, the like yet never heard of.
I think it cites us, brother, to the field,
That we, the sons of brave Plantagenet,
Each one already blazing by our meeds,
Should notwithstanding join our lights together
And over-shine the earth as this the world.
Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
#[allow(dead_code)]
struct Config {
block_size: usize,
vocab_size: usize,
n_layer: usize,
n_head: usize,
n_embd: usize,
}
#[allow(dead_code)]
impl Config {
fn config_7b() -> Self {
Self {
block_size: 4096,
vocab_size: 32000,
n_layer: 32,
n_head: 32,
n_embd: 4096,
}
}
fn config_13b() -> Self {
Self {
block_size: 4096,
vocab_size: 32000,
n_layer: 40,
n_head: 40,
n_embd: 5120,
}
}
fn config_30b() -> Self {
Self {
block_size: 4096,
vocab_size: 32000,
n_layer: 60,
n_head: 52,
n_embd: 6656,
}
}
fn config_65b() -> Self {
Self {
block_size: 4096,
vocab_size: 32000,
n_layer: 80,
n_head: 64,
n_embd: 8192,
}
}
}
struct Embedding {
embeddings: Tensor,
}
impl Embedding {
fn new(mut vb: VarBuilder, vocab_size: usize, n_embd: usize) -> Result<Self> {
let embeddings = vb.var("weight", (vocab_size, n_embd))?;
Ok(Self { embeddings })
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
Ok(Tensor::embedding(indexes, &self.embeddings)?)
}
}
struct Linear {
ws: Tensor,
bs: Option<Tensor>,
}
impl Linear {
#[allow(dead_code)]
fn new(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
let ws = vb.var("weight", (in_size, out_size))?;
let bs = vb.var("bias", out_size)?;
Ok(Self { ws, bs: Some(bs) })
}
fn new_no_bias(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
let ws = vb.var("weight", (in_size, out_size))?;
Ok(Self { ws, bs: None })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.matmul(&self.ws)?;
let y = match &self.bs {
None => x,
Some(bs) => x.broadcast_add(bs)?,
};
Ok(y)
}
}
struct RmsNorm {
scale: Tensor,
size: usize,
}
impl RmsNorm {
fn new(mut vb: VarBuilder, size: usize) -> Result<Self> {
let scale = vb.var("scale", &[size])?;
Ok(Self { scale, size })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let last_dim = x.dims().last().unwrap();
let norm_x = ((x * x)?.sum(&[x.rank() - 1])? / *last_dim as f64)?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let scale = self.scale.reshape(&[1, 1, self.size])?;
Ok((scale * x_normed)?)
}
}
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
}
fn silu(xs: &Tensor) -> Result<Tensor> {
Ok((xs / (xs.neg()?.exp()? + 1.0)?)?)
}
impl Mlp {
fn new(vb: VarBuilder, n_embd: usize) -> Result<Self> {
let n_hidden = 8 * n_embd / 3;
let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
Ok(Self {
c_fc1,
c_fc2,
c_proj,
})
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let _on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
// TODO: add an equivalent to where (or xla's select) so that we can use the following:
// let m = mask.where_cond(&on_true, on_false)?;
let m = on_false.clone();
Ok(m)
}
struct CausalSelfAttention {
c_attn: Linear,
c_proj: Linear,
n_head: usize,
n_embd: usize,
}
impl CausalSelfAttention {
fn new(vb: VarBuilder, n_head: usize, n_embd: usize) -> Result<Self> {
let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
Ok(Self {
c_attn,
c_proj,
n_head,
n_embd,
})
}
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let mut dims = x.dims().to_vec();
let v = dims.pop().unwrap();
dims.push(v / 2);
dims.push(2);
let x = x.reshape(dims)?;
let rank = x.rank();
let re_x = x.narrow(rank - 1, 0, 1)?;
let im_x = x.narrow(rank - 1, 1, 2)?;
let re_f = freqs_cis.narrow(rank - 1, 0, 1)?;
let im_f = freqs_cis.narrow(rank - 1, 1, 2)?;
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
let rope = Tensor::cat(&[&re, &im], rank - 1)?;
// TODO: Add the flatten op.
let mut dims = rope.dims().to_vec();
let v1 = dims.pop().unwrap();
let v2 = dims.pop().unwrap();
dims.push(v1 * v2);
let rope = rope.reshape(dims)?;
Ok(rope)
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let (b, t, c) = x.shape().r3()?;
let qkv = self.c_attn.forward(x)?;
let n_embd = self.n_embd;
let q = qkv.narrow(2, 0, n_embd)?;
let k = qkv.narrow(2, n_embd, 2 * n_embd)?;
let v = qkv.narrow(2, 2 * n_embd, 3 * n_embd)?;
let target_dim = [b, t, self.n_head, c / self.n_head];
let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
let v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, freqs_cis)?;
let k = self.apply_rotary_emb(&k, freqs_cis)?;
let k_shape = k.shape();
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
let device = x.device();
// TODO: If we support bool or u8 tensors, this would be better.
let mask = Tensor::new(1u32, &device)?
.broadcast_as(&[t, t])?
// TODO: .lower_triangle()?
.reshape(&[1, 1, t, t])?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let y = att.softmax(att.rank() - 1)?.matmul(&v)?;
let y = y.transpose(1, 2)?.reshape(&[b, t, c])?;
let y = self.c_proj.forward(&y)?;
Ok(y)
}
}
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd)?;
let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
Ok(Self {
rms_1,
attn,
rms_2,
mlp,
})
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
Ok(x)
}
}
struct Llama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
}
impl Llama {
fn new(vb: VarBuilder, config: &Config) -> Result<Self> {
let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
let wte = Embedding::new(
&vb / "transformer" / "wte",
config.vocab_size,
config.n_embd,
)?;
let blocks = (0..config.n_layer)
.map(|i| Block::new(&vb / "transformer" / "h" / i, config))
.collect::<Result<Vec<_>>>()?;
let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
Ok(Self {
wte,
blocks,
ln_f,
lm_head,
})
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let (_, t) = x.shape().r2()?;
let mut x = self.wte.forward(x)?;
for block in self.blocks.iter() {
x = block.forward(&x, freqs_cis)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.narrow(1, t - 1, t)?;
let logits = self.lm_head.forward(&x)?;
Ok(logits)
}
}
fn precompute_freqs_cis(config: &Config) -> Result<Tensor> {
let seq_len = CONTEXT_SIZE;
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
let theta = Tensor::new(theta.as_slice(), &candle::Device::Cpu)?;
let arange = Tensor::new(arange.as_slice(), &candle::Device::Cpu)?;
let idx_theta = arange
.reshape((arange.elem_count(), 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let shape = [1, 1, seq_len, n_elem / 2, 1];
let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
let last_dim = idx_theta_cos.rank() - 1;
Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim)?)
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 1.0)]
temperature: f64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
sample_len: usize,
}
fn main() -> Result<()> {
use rand::prelude::*;
use tokenizers::Tokenizer;
let args = Args::parse();
println!("loading tokenizer config");
let tokenizer = Tokenizer::from_file("llama-tokenizer.json").map_err(E::msg)?;
let mut tokens = tokenizer
.encode(START_PROMPT, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("loading weights");
let start_load = std::time::Instant::now();
let vb = VarBuilder::new::<f32>(); // TODO: load the weights from llama.npz
println!("loaded weights in {:?}", start_load.elapsed());
println!("building the model");
let config = Config::config_7b();
let llama = Llama::new(vb, &config)?;
println!("pre-computing the positional embeddings");
let freqs_cis = precompute_freqs_cis(&config)?;
println!("starting the inference loop");
let mut new_tokens = vec![];
let mut rng = thread_rng();
for index in 0..args.sample_len {
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
let input = Tensor::new(ctxt, &Device::Cpu)?;
let logits = llama.forward(&input, &freqs_cis)?;
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
let next_token = distr.sample(&mut rng) as u32;
tokens.push(next_token);
new_tokens.push(next_token);
println!(
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
);
}
println!(
"----\n{}\n----",
tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
Ok(())
}

View File

@ -0,0 +1,78 @@
use candle::{DType, Device, Result, Shape, Tensor, WithDType};
#[allow(dead_code)]
#[derive(Clone)]
struct NamedVar {
path: String,
dtype: DType,
shape: Shape,
}
#[derive(Clone)]
pub struct VarBuilder {
path: Vec<String>,
vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
default_dtype: DType,
}
#[allow(dead_code)]
pub struct VarStore {
vars: Vec<NamedVar>,
}
impl VarBuilder {
pub fn new<B: WithDType>() -> Self {
let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
Self {
path: vec![],
vars,
default_dtype: B::DTYPE,
}
}
pub fn len(&self) -> usize {
self.vars.borrow().len()
}
pub fn var<S: Into<Shape>>(&mut self, s: &str, shape: S) -> Result<Tensor> {
let shape = shape.into();
let path = format!("{}.{s}", self.path.join("."));
let mut vars = self.vars.borrow_mut();
let parameter = Tensor::zeros(&shape, self.default_dtype, &Device::Cpu);
vars.push(NamedVar {
path,
dtype: self.default_dtype,
shape,
});
parameter
}
pub fn into_store(self) -> VarStore {
let vars = self.vars.borrow();
VarStore {
vars: vars.to_vec(),
}
}
}
impl<S: ToString> std::ops::Div<S> for &VarBuilder {
type Output = VarBuilder;
fn div(self, rhs: S) -> VarBuilder {
let mut path = self.path.clone();
path.push(rhs.to_string());
VarBuilder {
path,
vars: self.vars.clone(),
default_dtype: self.default_dtype,
}
}
}
impl<S: ToString> std::ops::Div<S> for VarBuilder {
type Output = VarBuilder;
fn div(self, rhs: S) -> VarBuilder {
&self / rhs
}
}