diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 2cf71bb5..c02c65b9 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -1,177 +1,21 @@ // https://github.com/karpathy/llama2.c -#![allow(dead_code)] -#![allow(unused)] #[cfg(feature = "mkl")] extern crate intel_mkl_src; mod model; +mod weights; use clap::{Parser, Subcommand}; use anyhow::{Error as E, Result}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor}; -use candle_nn::{Embedding, Linear, VarBuilder}; +use byteorder::{LittleEndian, ReadBytesExt}; +use candle::{IndexOp, Tensor}; use candle_transformers::generation::LogitsProcessor; use std::io::Write; use tokenizers::Tokenizer; use model::{Config, Llama}; - -struct TransformerWeights { - // token embedding table - token_embedding_table: Tensor, // (vocab_size, dim) - // weights for rmsnorms - rms_att_weight: Tensor, // (layer, dim) rmsnorm weights - rms_ffn_weight: Tensor, // (layer, dim) - // weights for matmuls - wq: Tensor, // (layer, dim, dim) - wk: Tensor, // (layer, dim, dim) - wv: Tensor, // (layer, dim, dim) - wo: Tensor, // (layer, dim, dim) - // weights for ffn - w1: Tensor, // (layer, hidden_dim, dim) - w2: Tensor, // (layer, dim, hidden_dim) - w3: Tensor, // (layer, hidden_dim, dim) - // final rmsnorm - rms_final_weight: Tensor, // (dim,) - // freq_cis for RoPE relatively positional embeddings - freq_cis_real: Tensor, // (seq_len, head_size/2) - freq_cis_imag: Tensor, // (seq_len, head_size/2) -} - -fn read_i32(r: &mut R) -> Result { - let mut buf = [0u8; 4]; - r.read_exact(&mut buf)?; - Ok(i32::from_le_bytes(buf)) -} - -fn read_tensor>( - r: &mut R, - shape: S, - dev: &Device, -) -> Result { - let shape = shape.into(); - let mut data_t = vec![0f32; shape.elem_count()]; - r.read_f32_into::(&mut data_t)?; - let tensor = Tensor::from_vec(data_t, shape, dev)?; - Ok(tensor) -} - -impl Config { - fn from_reader(r: &mut R) -> Result { - let dim = read_i32(r)? as usize; - let hidden_dim = read_i32(r)? as usize; - let n_layers = read_i32(r)? as usize; - let n_heads = read_i32(r)? as usize; - let n_kv_heads = read_i32(r)? as usize; - let vocab_size = read_i32(r)? as usize; - let seq_len = read_i32(r)? as usize; - Ok(Self { - dim, - hidden_dim, - n_layers, - n_heads, - n_kv_heads, - vocab_size, - seq_len, - norm_eps: 1e-5, - }) - } - - fn head_size(&self) -> usize { - self.dim / self.n_heads - } -} - -impl TransformerWeights { - fn from_reader(r: &mut R, c: &Config, dev: &Device) -> Result { - let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?; - let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?; - let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; - let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; - let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; - let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; - let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?; - let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?; - let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?; - let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?; - let rms_final_weight = read_tensor(r, c.dim, dev)?; - let head_size = c.head_size(); - let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?; - let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?; - Ok(Self { - token_embedding_table, - rms_att_weight, - wq, - wk, - wv, - wo, - rms_ffn_weight, - w1, - w2, - w3, - rms_final_weight, - freq_cis_real, - freq_cis_imag, - }) - } - - fn var_builder(&self, cfg: &Config, device: &Device) -> Result { - let mut ws = std::collections::HashMap::new(); - let mut insert = |name: &str, t: Tensor| { - ws.insert(name.to_string(), t); - }; - insert("rot.freq_cis_real", self.freq_cis_real.clone()); - insert("rot.freq_cis_imag", self.freq_cis_imag.clone()); - insert( - "model.embed_tokens.weight", - self.token_embedding_table.clone(), - ); - insert("lm_head.weight", self.token_embedding_table.clone()); - insert("model.norm.weight", self.rms_final_weight.clone()); - for layer in 0..cfg.n_layers { - ws.insert( - format!("model.layers.{layer}.self_attn.q_proj.weight"), - self.wq.i(layer)?, - ); - ws.insert( - format!("model.layers.{layer}.self_attn.k_proj.weight"), - self.wk.i(layer)?, - ); - ws.insert( - format!("model.layers.{layer}.self_attn.v_proj.weight"), - self.wv.i(layer)?, - ); - ws.insert( - format!("model.layers.{layer}.self_attn.o_proj.weight"), - self.wo.i(layer)?, - ); - ws.insert( - format!("model.layers.{layer}.mlp.gate_proj.weight"), - self.w1.i(layer)?, - ); - ws.insert( - format!("model.layers.{layer}.mlp.down_proj.weight"), - self.w2.i(layer)?, - ); - ws.insert( - format!("model.layers.{layer}.mlp.up_proj.weight"), - self.w3.i(layer)?, - ); - ws.insert( - format!("model.layers.{layer}.input_layernorm.weight"), - self.rms_att_weight.i(layer)?, - ); - ws.insert( - format!("model.layers.{layer}.post_attention_layernorm.weight"), - self.rms_ffn_weight.i(layer)?, - ); - } - let vb = VarBuilder::from_tensors(ws, DType::F32, device); - Ok(vb) - } -} +use weights::TransformerWeights; #[derive(Parser, Debug, Clone)] struct InferenceCmd { @@ -308,6 +152,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> { tokens.concat() } Some(pretokenized_dir) => { + // Use shard 0 for the test split, similar to llama2.c + // https://github.com/karpathy/llama2.c/blob/ce05cc28cf1e3560b873bb21837638a434520a67/tinystories.py#L121 let path = std::path::PathBuf::from(pretokenized_dir).join("data00.bin"); let bytes = std::fs::read(path)?; // Tokens are encoded as u16. @@ -377,7 +223,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { if tokens.len() >= model.config.seq_len { break; } - let start_gen = std::time::Instant::now(); let context_size = if index > 0 { 1 } else { tokens.len() }; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs index a92367e6..618bf67c 100644 --- a/candle-examples/examples/llama2-c/model.rs +++ b/candle-examples/examples/llama2-c/model.rs @@ -106,7 +106,6 @@ struct CausalSelfAttention { n_key_value_head: usize, head_dim: usize, cache: Cache, - max_seq_len: usize, } impl CausalSelfAttention { @@ -198,7 +197,6 @@ impl CausalSelfAttention { n_key_value_head: cfg.n_kv_heads, head_dim: cfg.dim / cfg.n_heads, cache: cache.clone(), - max_seq_len: cfg.seq_len, }) } } @@ -291,7 +289,7 @@ pub struct Llama { impl Llama { pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result { - let (_b_sz, seq_len) = x.dims2()?; + let (_b_sz, _seq_len) = x.dims2()?; let mut x = self.wte.forward(x)?; for (block_idx, block) in self.blocks.iter().enumerate() { x = block.forward(&x, index_pos, block_idx)?; diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs new file mode 100644 index 00000000..ae1fd6d9 --- /dev/null +++ b/candle-examples/examples/llama2-c/weights.rs @@ -0,0 +1,161 @@ +use anyhow::Result; +use byteorder::{LittleEndian, ReadBytesExt}; +use candle::{DType, Device, IndexOp, Shape, Tensor}; +use candle_nn::VarBuilder; + +use crate::model::Config; + +pub struct TransformerWeights { + // token embedding table + token_embedding_table: Tensor, // (vocab_size, dim) + // weights for rmsnorms + rms_att_weight: Tensor, // (layer, dim) rmsnorm weights + rms_ffn_weight: Tensor, // (layer, dim) + // weights for matmuls + wq: Tensor, // (layer, dim, dim) + wk: Tensor, // (layer, dim, dim) + wv: Tensor, // (layer, dim, dim) + wo: Tensor, // (layer, dim, dim) + // weights for ffn + w1: Tensor, // (layer, hidden_dim, dim) + w2: Tensor, // (layer, dim, hidden_dim) + w3: Tensor, // (layer, hidden_dim, dim) + // final rmsnorm + rms_final_weight: Tensor, // (dim,) + // freq_cis for RoPE relatively positional embeddings + freq_cis_real: Tensor, // (seq_len, head_size/2) + freq_cis_imag: Tensor, // (seq_len, head_size/2) +} + +fn read_i32(r: &mut R) -> Result { + let mut buf = [0u8; 4]; + r.read_exact(&mut buf)?; + Ok(i32::from_le_bytes(buf)) +} + +fn read_tensor>( + r: &mut R, + shape: S, + dev: &Device, +) -> Result { + let shape = shape.into(); + let mut data_t = vec![0f32; shape.elem_count()]; + r.read_f32_into::(&mut data_t)?; + let tensor = Tensor::from_vec(data_t, shape, dev)?; + Ok(tensor) +} + +impl Config { + pub fn from_reader(r: &mut R) -> Result { + let dim = read_i32(r)? as usize; + let hidden_dim = read_i32(r)? as usize; + let n_layers = read_i32(r)? as usize; + let n_heads = read_i32(r)? as usize; + let n_kv_heads = read_i32(r)? as usize; + let vocab_size = read_i32(r)? as usize; + let seq_len = read_i32(r)? as usize; + Ok(Self { + dim, + hidden_dim, + n_layers, + n_heads, + n_kv_heads, + vocab_size, + seq_len, + norm_eps: 1e-5, + }) + } + + pub fn head_size(&self) -> usize { + self.dim / self.n_heads + } +} + +impl TransformerWeights { + pub fn from_reader(r: &mut R, c: &Config, dev: &Device) -> Result { + let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?; + let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?; + let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?; + let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?; + let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?; + let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?; + let rms_final_weight = read_tensor(r, c.dim, dev)?; + let head_size = c.head_size(); + let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?; + let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?; + Ok(Self { + token_embedding_table, + rms_att_weight, + wq, + wk, + wv, + wo, + rms_ffn_weight, + w1, + w2, + w3, + rms_final_weight, + freq_cis_real, + freq_cis_imag, + }) + } + + pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result { + let mut ws = std::collections::HashMap::new(); + let mut insert = |name: &str, t: Tensor| { + ws.insert(name.to_string(), t); + }; + insert("rot.freq_cis_real", self.freq_cis_real.clone()); + insert("rot.freq_cis_imag", self.freq_cis_imag.clone()); + insert( + "model.embed_tokens.weight", + self.token_embedding_table.clone(), + ); + insert("lm_head.weight", self.token_embedding_table.clone()); + insert("model.norm.weight", self.rms_final_weight.clone()); + for layer in 0..cfg.n_layers { + ws.insert( + format!("model.layers.{layer}.self_attn.q_proj.weight"), + self.wq.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.self_attn.k_proj.weight"), + self.wk.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.self_attn.v_proj.weight"), + self.wv.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.self_attn.o_proj.weight"), + self.wo.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.mlp.gate_proj.weight"), + self.w1.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.mlp.down_proj.weight"), + self.w2.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.mlp.up_proj.weight"), + self.w3.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.input_layernorm.weight"), + self.rms_att_weight.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.post_attention_layernorm.weight"), + self.rms_ffn_weight.i(layer)?, + ); + } + let vb = VarBuilder::from_tensors(ws, DType::F32, device); + Ok(vb) + } +}