From ece3ec6167220b66a141605deb0a4ffd0136120d Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 28 Jun 2023 12:27:03 +0000
Subject: [PATCH] Final updates -> moving to deterministic for easier comparison.

---
 candle-core/examples/llama/main.rs      | 64 ++++++++--------
 candle-core/examples/llama/var_store.rs | 98 ++++++++++++++++++++++---
 2 files changed, 120 insertions(+), 42 deletions(-)

diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 066025b1..8af465a9 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -19,9 +19,7 @@ use candle_hub::{api::Api, Repo, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
-// mod var_store;
-// use var_store::VarBuilder;
-
+mod var_store;
 mod weights;
 
 const CONTEXT_SIZE: usize = 512;
@@ -196,11 +194,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 
 impl Mlp {
     fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
-        // let n_hidden = 8 * n_embd / 3;
-        // let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
-        // let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
-        // let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
-        // let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
         Self {
             c_fc1,
             c_fc2,
@@ -244,10 +237,6 @@ impl Cache {
         let mask: Vec<_> = (0..t)
            .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
            .collect();
-        // Once lower_triangle is available, use the followig:
-        //let mask = Tensor::new(1u32, &device)?
-        //    .broadcast_as(&[t, t])?
-        //    .lower_triangle()?
         let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
         masks.insert(t, mask.clone());
         Ok(mask)
@@ -265,13 +254,10 @@ struct CausalSelfAttention {
 
 impl CausalSelfAttention {
     fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
-        // let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
-        // let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
         Self {
             c_attn,
             c_proj,
             n_head,
-            // n_embd,
             cache: cache.clone(),
         }
     }
@@ -408,6 +394,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Use npy instead of safetensors
+    #[arg(long)]
+    npy: bool,
+
     /// The temperature used to generate samples.
     #[arg(long, default_value_t = 1.0)]
     temperature: f64,
@@ -419,7 +409,7 @@ struct Args {
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    use rand::prelude::*;
+    //use rand::prelude::*;
     use tokenizers::Tokenizer;
 
     let args = Args::parse();
@@ -431,6 +421,7 @@ async fn main() -> Result<()> {
     let api = Api::new()?;
     let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
     let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+    println!("Filename {tokenizer_filename:?}");
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let mut tokens = tokenizer
         .encode(START_PROMPT, true)
@@ -446,30 +437,25 @@ async fn main() -> Result<()> {
         let filename = api.get(&repo, rfilename).await?;
         filenames.push(filename);
     }
 
-    // let weight_path = std::path::Path::new("llama.npz");
-    // let weights = if weight_path.exists() {
-    //     println!("loading weights from {weight_path:?}");
-    //     let start_load = std::time::Instant::now();
-    //     let tensors = Tensor::read_npz(weight_path)?;
-    //     println!("loaded weights in {:?}", start_load.elapsed());
-    //     let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
-    //     Some(tensors)
-    // } else {
-    //     println!("cannot find {weight_path:?}, using zero weights");
-    //     None
-    // };
-    // let vb = VarBuilder::new::<f32>(&device, weights);
     println!("building the model");
     let config = Config::config_7b();
     let cache = Cache::new(&device);
-    let llama = Llama::load(&device, &filenames, &cache, &config)?;
+    let start = std::time::Instant::now();
+    let llama = if args.npy {
+        println!("building the model (NPY)");
+        Llama::load_npy(&device, &filenames, &cache, &config)?
+    } else {
+        println!("building the model (SF)");
+        Llama::load(&device, &filenames, &cache, &config)?
+    };
+    println!("Loaded in {:?}", start.elapsed());
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
 
     println!("starting the inference loop");
     let mut new_tokens = vec![];
-    let mut rng = thread_rng();
+    //let mut rng = thread_rng();
     let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
@@ -477,8 +463,20 @@ async fn main() -> Result<()> {
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
-        let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
-        let next_token = distr.sample(&mut rng) as u32;
+        let next_token = logits_v
+            .iter()
+            .enumerate()
+            .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
+                if &val_max > val {
+                    (idx_max, val_max)
+                } else {
+                    (idx, *val)
+                }
+            })
+            .0 as u32;
+        // let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+
+        // let next_token = distr.sample(&mut rng) as u32;
         tokens.push(next_token);
         new_tokens.push(next_token);
         println!(
diff --git a/candle-core/examples/llama/var_store.rs b/candle-core/examples/llama/var_store.rs
index 1a400edc..8771170e 100644
--- a/candle-core/examples/llama/var_store.rs
+++ b/candle-core/examples/llama/var_store.rs
@@ -1,5 +1,7 @@
+use super::*;
 use candle::{DType, Device, Result, Shape, Tensor, WithDType};
 use std::collections::HashMap;
+use std::path::PathBuf;
 use std::sync::Arc;
 
 #[allow(dead_code)]
@@ -40,22 +42,15 @@ impl VarBuilder {
         self.vars.borrow().len()
     }
 
-    pub fn var<S: Into<Shape>>(&mut self, s: &str, shape: S) -> Result<Tensor> {
-        let shape = shape.into();
+    pub fn var(&self, s: &str) -> Result<Tensor> {
         let path = format!("{}.{s}", self.path.join("."));
-        let mut vars = self.vars.borrow_mut();
         let parameter = match self.tensors.as_ref() {
-            None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
+            None => panic!("Cannot find tensors"),
             Some(tensors) => match tensors.get(&path) {
                 Some(tensor) => tensor.to_device(&self.default_device)?,
                 None => panic!("cannot find tensor for {path}"),
             },
         };
-        vars.push(NamedVar {
-            path,
-            dtype: self.default_dtype,
-            shape,
-        });
         Ok(parameter)
     }
 
@@ -90,3 +85,88 @@ impl<S: ToString> std::ops::Div<S> for VarBuilder {
         &self / rhs
     }
 }
+
+impl Embedding {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let embeddings = vb.var("weight")?;
+        Ok(Self { embeddings })
+    }
+}
+
+impl Linear {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let weight = vb.var("weight")?.t()?;
+        Ok(Self { weight })
+    }
+}
+
+impl RmsNorm {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let scale = vb.var("scale")?;
+        Ok(Self::new(scale))
+    }
+}
+
+impl CausalSelfAttention {
+    fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+        let c_attn = Linear::load_npy(&vb / "c_attn")?;
+        let c_proj = Linear::load_npy(&vb / "c_proj")?;
+        Ok(Self::new(c_attn, c_proj, config.n_head, cache))
+    }
+}
+
+impl Mlp {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let c_fc1 = Linear::load_npy(&vb / "c_fc1")?;
+        let c_fc2 = Linear::load_npy(&vb / "c_fc2")?;
+        let c_proj = Linear::load_npy(&vb / "c_proj")?;
+        Ok(Self::new(c_fc1, c_fc2, c_proj))
+    }
+}
+
+impl Block {
+    fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+        let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?;
+        let mlp = Mlp::load_npy(&vb / "mlp")?;
+        let input_layernorm = RmsNorm::load_npy(&vb / "rms_1")?;
+        let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?;
+        Ok(Self::new(
+            input_layernorm,
+            attn,
+            post_attention_layernorm,
+            mlp,
+        ))
+    }
+}
+
+impl Llama {
+    pub fn load_npy(
+        device: &Device,
+        _filenames: &[PathBuf],
+        cache: &Cache,
+        config: &Config,
+    ) -> Result<Self> {
+        let weight_path = std::path::Path::new("/data/llama.npz");
+        let weights = if weight_path.exists() {
+            println!("loading weights from {weight_path:?}");
+            let start_load = std::time::Instant::now();
+            let tensors = Tensor::read_npz(weight_path)?;
+            println!("loaded weights in {:?}", start_load.elapsed());
+            let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
+            Some(tensors)
+        } else {
+            println!("cannot find {weight_path:?}, using zero weights");
+            None
+        };
+        let vb = VarBuilder::new::<f32>(device, weights);
+
+        let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
+        let lm_head = Linear::load_npy(&vb / "lm_head")?;
+        let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?;
+        let blocks: Vec<_> = (0..config.n_layer)
+            .map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap())
+            .collect();
+
+        Ok(Self::new(wte, blocks, norm, lm_head))
+    }
+}
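
Note on the selection change in main.rs: the patch replaces `WeightedIndex` sampling with an argmax fold over the probabilities, so generation becomes deterministic and outputs can be compared across runs and backends. Since dividing the logits by a positive temperature preserves their ordering, the temperature no longer affects which token is chosen. A minimal standalone sketch of the same greedy selection (a hypothetical `argmax` helper, not part of the patch), written with `max_by` instead of a manual fold:

    fn argmax(probs: &[f32]) -> u32 {
        probs
            .iter()
            .enumerate()
            // `max_by` keeps the last maximum on ties, matching the fold in
            // the patch; `partial_cmp` only panics if a probability is NaN.
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx as u32)
            .unwrap_or(0)
    }

    fn main() {
        // The largest probability wins, on every run.
        assert_eq!(argmax(&[0.1, 0.7, 0.2]), 1);
    }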
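The `load_npy` helpers rely on the `Div` overloads in var_store.rs to build dotted tensor names such as `transformer.h.0.attn.c_attn.weight` before looking them up in the npz map. A minimal sketch of those path-building semantics (assumed to mirror the patch's `VarBuilder`, with a hypothetical `PathBuilder` type standing in):

    struct PathBuilder {
        path: Vec<String>,
    }

    impl<S: ToString> std::ops::Div<S> for &PathBuilder {
        type Output = PathBuilder;
        // Each `/ segment` clones the current path and appends one component.
        fn div(self, rhs: S) -> PathBuilder {
            let mut path = self.path.clone();
            path.push(rhs.to_string());
            PathBuilder { path }
        }
    }

    impl<S: ToString> std::ops::Div<S> for PathBuilder {
        type Output = PathBuilder;
        // The owned impl delegates to the borrowed one, as in the patch,
        // so `/` chains naturally: each `/` yields an owned builder that
        // the next `/` consumes.
        fn div(self, rhs: S) -> PathBuilder {
            &self / rhs
        }
    }

    impl PathBuilder {
        // Counterpart of VarBuilder::var: dot-join the path plus leaf name
        // to form the key looked up in the tensor map.
        fn name(&self, s: &str) -> String {
            format!("{}.{s}", self.path.join("."))
        }
    }

    fn main() {
        let root = PathBuilder { path: vec![] };
        // Integer segments work because they satisfy ToString, which is
        // what lets Block::load_npy write `/ "h" / i` for the layer index.
        let vb = &root / "transformer" / "h" / 0 / "attn" / "c_attn";
        assert_eq!(vb.name("weight"), "transformer.h.0.attn.c_attn.weight");
    }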