From e29dae044d56fb5f840f27d7317de51514dad749 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 27 Jun 2023 17:24:26 +0000
Subject: [PATCH 1/3] Tmp.

---
 candle-core/Cargo.toml                |   3 +
 candle-core/examples/llama/main.rs    | 146 +++++++++++---------------
 candle-core/examples/llama/weights.rs | 144 +++++++++++++++++++++++++
 candle-hub/src/api.rs                 |   8 +-
 candle-hub/src/lib.rs                 |  13 ++-
 5 files changed, 229 insertions(+), 85 deletions(-)
 create mode 100644 candle-core/examples/llama/weights.rs

diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 75b48df8..bd8d4104 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -27,6 +27,9 @@ anyhow = "1"
 clap = { version = "4.2.4", features = ["derive"] }
 rand = "0.8.5"
 tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
+tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
+candle-hub = { path = "../candle-hub" }
+memmap2 = "0.7.1"
 
 [features]
 default = ["cuda"]
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index eb681f4b..f1b2c3d4 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -15,11 +15,14 @@
 use anyhow::{Error as E, Result};
 use clap::Parser;
 use candle::{DType, Device, Tensor};
+use candle_hub::{Repo, api::Api, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
-mod var_store;
-use var_store::VarBuilder;
+// mod var_store;
+// use var_store::VarBuilder;
+
+mod weights;
 
 const CONTEXT_SIZE: usize = 512;
 const START_PROMPT: &str = r"
@@ -131,9 +134,8 @@ struct Embedding {
 }
 
 impl Embedding {
-    fn new(mut vb: VarBuilder, vocab_size: usize, n_embd: usize) -> Result<Self> {
-        let embeddings = vb.var("weight", (vocab_size, n_embd))?;
-        Ok(Self { embeddings })
+    fn new(embeddings: Tensor) -> Self {
+        Self { embeddings }
     }
 
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
@@ -145,42 +147,27 @@
 }
 
 struct Linear {
-    ws: Tensor,
-    bs: Option<Tensor>,
+    weight: Tensor,
 }
 
 impl Linear {
-    #[allow(dead_code)]
-    fn new(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
-        let ws = vb.var("weight", (in_size, out_size))?;
-        let bs = vb.var("bias", out_size)?;
-        Ok(Self { ws, bs: Some(bs) })
-    }
-
-    fn new_no_bias(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
-        let ws = vb.var("weight", (in_size, out_size))?;
-        Ok(Self { ws, bs: None })
+    fn new(weight: Tensor) -> Self {
+        Self { weight }
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.ws.to_dtype(DType::F32)?)?;
-        let y = match &self.bs {
-            None => x,
-            Some(bs) => x.broadcast_add(&bs.to_dtype(DType::F32)?)?,
-        };
-        Ok(y)
+        let x = x.matmul(&self.weight.to_dtype(DType::F32)?.t()?)?;
+        Ok(x)
     }
 }
 
 struct RmsNorm {
     scale: Tensor,
-    size: usize,
 }
 
 impl RmsNorm {
-    fn new(mut vb: VarBuilder, size: usize) -> Result<Self> {
-        let scale = vb.var("scale", &[size])?;
-        Ok(Self { scale, size })
+    fn new(scale: Tensor) -> Self {
+        Self { scale }
    }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -188,10 +175,11 @@
         let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let size = self.scale.shape().r1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
-            .broadcast_as((seq_len, self.size))?;
+            .broadcast_as((seq_len, size))?;
         Ok((scale * x_normed)?)
     }
 }
@@ -207,17 +195,17 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 }
 
 impl Mlp {
-    fn new(vb: VarBuilder, n_embd: usize) -> Result<Self> {
-        let n_hidden = 8 * n_embd / 3;
-        let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
-        let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
-        let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
-        let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
-        Ok(Self {
+    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
+        // let n_hidden = 8 * n_embd / 3;
+        // let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
+        // let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
+        // let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
+        // let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
+        Self {
             c_fc1,
             c_fc2,
             c_proj,
-        })
+        }
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -256,7 +244,7 @@ impl Cache {
         let mask: Vec<_> = (0..t)
             .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
             .collect();
-        // Once lower_triangle is available, use the following:
+        // Once lower_triangle is available, use the following:
         //let mask = Tensor::new(1u32, &device)?
         //    .broadcast_as(&[t, t])?
         //    .lower_triangle()?
@@ -271,21 +259,21 @@ struct CausalSelfAttention {
     c_attn: Linear,
     c_proj: Linear,
     n_head: usize,
-    n_embd: usize,
+    // n_embd: usize,
     cache: Cache,
 }
 
 impl CausalSelfAttention {
-    fn new(vb: VarBuilder, n_head: usize, n_embd: usize, cache: &Cache) -> Result<Self> {
-        let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
-        let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
-        Ok(Self {
+    fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
+        // let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
+        // let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
+        Self {
             c_attn,
             c_proj,
             n_head,
-            n_embd,
+            // n_embd,
             cache: cache.clone(),
-        })
+        }
     }
 
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
@@ -313,7 +301,7 @@ impl CausalSelfAttention {
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let (t, c) = x.shape().r2()?;
         let qkv = self.c_attn.forward(x)?;
-        let n_embd = self.n_embd;
+        let n_embd = c;
         let q = qkv.narrow(1, 0, n_embd)?;
         let k = qkv.narrow(1, n_embd, n_embd)?;
         let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
@@ -344,17 +332,13 @@ struct Block {
 }
 
 impl Block {
-    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
-        let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
-        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd, cache)?;
-        let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
-        let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
-        Ok(Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+        Self {
             rms_1,
             attn,
             rms_2,
             mlp,
-        })
+        }
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
@@ -372,23 +356,13 @@ struct Llama {
 }
 
 impl Llama {
-    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
-        let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
-        let wte = Embedding::new(
-            &vb / "transformer" / "wte",
-            config.vocab_size,
-            config.n_embd,
-        )?;
-        let blocks = (0..config.n_layer)
-            .map(|i| Block::new(&vb / "transformer" / "h" / i, cache, config))
-            .collect::<Result<Vec<_>>>()?;
-        let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
-        Ok(Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+        Self {
             wte,
             blocks,
             ln_f,
             lm_head,
-        })
+        }
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
@@ -443,7 +417,8 @@ struct Args {
     sample_len: usize,
 }
 
-fn main() -> Result<()> {
+#[tokio::main]
+async fn main() -> Result<()> {
     use rand::prelude::*;
     use tokenizers::Tokenizer;
 
@@ -453,32 +428,39 @@ fn main() -> Result<()> {
     } else {
         Device::new_cuda(0)?
     };
-    println!("loading tokenizer config");
-    let tokenizer = Tokenizer::from_file("llama-tokenizer.json").map_err(E::msg)?;
+    let api = Api::new()?;
+    let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+    let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let mut tokens = tokenizer
         .encode(START_PROMPT, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
 
-    let weight_path = std::path::Path::new("llama.npz");
-    let weights = if weight_path.exists() {
-        println!("loading weights from {weight_path:?}");
-        let start_load = std::time::Instant::now();
-        let tensors = Tensor::read_npz(weight_path)?;
-        println!("loaded weights in {:?}", start_load.elapsed());
-        let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
-        Some(tensors)
-    } else {
-        println!("cannot find {weight_path:?}, using zero weights");
-        None
-    };
-    let vb = VarBuilder::new::<f32>(&device, weights);
+    let mut filenames = vec![];
+    for rfilename in ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]{
+        let filename = api.get(&repo, rfilename).await?;
+        filenames.push(filename);
+    }
+    // let weight_path = std::path::Path::new("llama.npz");
+    // let weights = if weight_path.exists() {
+    //     println!("loading weights from {weight_path:?}");
+    //     let start_load = std::time::Instant::now();
+    //     let tensors = Tensor::read_npz(weight_path)?;
+    //     println!("loaded weights in {:?}", start_load.elapsed());
+    //     let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
+    //     Some(tensors)
+    // } else {
+    //     println!("cannot find {weight_path:?}, using zero weights");
+    //     None
+    // };
+    // let vb = VarBuilder::new::<f32>(&device, weights);
 
     println!("building the model");
     let config = Config::config_7b();
     let cache = Cache::new(&device);
-    let llama = Llama::new(vb, &cache, &config)?;
+    let llama = Llama::load(&device, &filenames, &cache, &config)?;
 
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
new file mode 100644
index 00000000..2a350f65
--- /dev/null
+++ b/candle-core/examples/llama/weights.rs
@@ -0,0 +1,144 @@
+use memmap2::MmapOptions;
+use candle::{Device, Result, Shape, Tensor, WithDType};
+use std::fs::File;
+use std::path::PathBuf;
+use super::*;
+use safetensors::{SafeTensors, tensor::{Dtype, TensorView}};
+use half::f16;
+
+fn convert<'a>(view: TensorView<'a>, device: &Device) -> Result<Tensor>{
+    match view.dtype(){
+        Dtype::F16 => {
+            let v = view.data();
+            if (v.as_ptr() as usize) % 2 == 0 {
+                // SAFETY This is safe because we just checked that this
+                // was correctly aligned.
+                let data: &[f16] =
+                    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
+                Tensor::from_slice(data, view.shape(), device)
+            } else {
+                let mut c = Vec::with_capacity(v.len() / 2);
+                let mut i = 0;
+                while i < v.len() {
+                    c.push(f16::from_le_bytes([v[i], v[i + 1]]));
+                    i += 2;
+                }
+                Tensor::from_slice(&c, view.shape(), device)
+            }
+
+        }
+        dt => todo!("Unhandled dtype {dt:?}")
+    }
+}
+
+pub struct VarBuilder<'a>{
+    routing: HashMap<String, usize>,
+    safetensors: Vec<SafeTensors<'a>>,
+    device: Device,
+}
+
+
+impl<'a> VarBuilder<'a>{
+    pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self{
+        let mut routing = HashMap::new();
+        for (index, sf) in safetensors.iter().enumerate(){
+            for k in sf.names(){
+                routing.insert(k.to_string(), index);
+            }
+        }
+
+        Self{
+            safetensors,
+            device,
+            routing
+        }
+    }
+
+    pub fn get(&self, tensor_name: &str) -> Result<Tensor>{
+        // Unwrap or 0 just to let the proper error flow.
+        let index = self.routing.get(tensor_name).unwrap_or(&0);
+        let view = self.safetensors[*index].tensor(tensor_name).unwrap();
+        let tensor = convert(view, &self.device)?;
+        Ok(tensor)
+
+    }
+}
+
+impl Linear{
+    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self>{
+        let weight = vb.get(&format!("{prefix}.weight"))?;
+        Ok(Self::new(weight))
+    }
+
+    fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self>{
+        let weights: Vec<_> = prefixes.iter().map(|p| vb.get(&format!("{p}.weight")).unwrap()).collect();
+        println!("shapes {:?}", weights.iter().map(|w| w.shape()).collect::<Vec<_>>());
+        let weight = Tensor::cat(&weights, 0)?;
+        Ok(Self::new(weight))
+    }
+}
+
+impl RmsNorm{
+    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self>{
+        let scale = vb.get(&format!("{prefix}.weight"))?;
+        Ok(Self::new(scale))
+    }
+}
+
+impl CausalSelfAttention{
+    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self>{
+        let c_attn = Linear::load_multi(&[&format!("{prefix}.q_proj"), &format!("{prefix}.k_proj"), &format!("{prefix}.v_proj")], vb)?;
+        let o_proj = Linear::load(&format!("{prefix}.o_proj"), vb)?;
+        Ok(Self::new(c_attn,o_proj, config.n_head, cache))
+    }
+}
+
+impl Mlp{
+    fn load(prefix: &str, vb: &VarBuilder, config: &Config) -> Result<Self>{
+        let c_fc1 = Linear::load(&format!("{prefix}.gate_proj"), vb)?;
+        let c_fc2 = Linear::load(&format!("{prefix}.up_proj"), vb)?;
+        let c_proj = Linear::load(&format!("{prefix}.down_proj"), vb)?;
+        Ok(Self::new(c_fc1, c_fc2, c_proj))
+    }
+}
+
+impl Block{
+    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self>{
+        let attn = CausalSelfAttention::load(&format!("{prefix}.self_attn"), vb, cache, config)?;
+        let mlp = Mlp::load(&format!("{prefix}.mlp"), vb, config)?;
+        let input_layernorm = RmsNorm::load(&format!("{prefix}.input_layernorm"), vb)?;
+        let post_attention_layernorm = RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
+        Ok(Self::new(input_layernorm, attn, post_attention_layernorm, mlp))
+    }
+}
+
+impl Llama{
+    pub fn load(device: &Device, filenames: &[PathBuf], cache: &Cache, config: &Config) -> Result<Self>{
+        let handles: Vec<_> = filenames.iter().map(|f| {
+            let file = File::open(f).unwrap();
+            let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
+            buffer
+        }).collect();
+        let tensors: Vec<_> = handles.iter().map(|h| {
+            let tensors = SafeTensors::deserialize(h).unwrap();
+            tensors
+        }).collect();
+
+        let vb = VarBuilder::new(tensors, device.clone());
+
+        let embedding = vb.get("model.embed_tokens.weight")?;
+        let wte = Embedding::new(embedding);
+        let lm_head = Linear::load("lm_head", &vb)?;
+        let norm = RmsNorm::load("model.norm", &vb)?;
+        let blocks: Vec<_> = (0..config.n_layer).map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap()).collect();
+
+        Ok(Self::new(
+            wte,
+            blocks,
+            norm,
+            lm_head
+        ))
+    }
+}
+
+
diff --git a/candle-hub/src/api.rs b/candle-hub/src/api.rs
index f4d9c1ec..bdf39cde 100644
--- a/candle-hub/src/api.rs
+++ b/candle-hub/src/api.rs
@@ -378,6 +378,12 @@ impl Api {
         let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
         let filename = temp_filename();
 
+        // Create the file and set everything properly
+        tokio::fs::File::create(&filename)
+            .await?
+            .set_len(length as u64)
+            .await?;
+
         let chunk_size = self.chunk_size;
         for start in (0..length).step_by(chunk_size) {
             let url = url.to_string();
@@ -391,6 +397,7 @@ impl Api {
             let parallel_failures_semaphore = parallel_failures_semaphore.clone();
             let progress = progressbar.clone();
             handles.push(tokio::spawn(async move {
+                println!("Start {start:?} - {stop:?}");
                 let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
                 let mut i = 0;
                 if parallel_failures > 0 {
@@ -440,7 +447,6 @@ impl Api {
         let range = format!("bytes={start}-{stop}");
         let mut file = tokio::fs::OpenOptions::new()
             .write(true)
-            .create(true)
             .open(filename)
             .await?;
         file.seek(SeekFrom::Start(start as u64)).await?;
diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs
index fbfef5e7..c5965d1f 100644
--- a/candle-hub/src/lib.rs
+++ b/candle-hub/src/lib.rs
@@ -53,7 +53,11 @@ impl Cache {
         let commit_hash = std::fs::read_to_string(commit_path).ok()?;
         let mut pointer_path = self.pointer_path(repo, &commit_hash);
         pointer_path.push(filename);
-        Some(pointer_path)
+        if pointer_path.exists(){
+            Some(pointer_path)
+        }else{
+            None
+        }
     }
 
     /// Creates a reference in the cache directory that points branches to the correct
@@ -146,7 +150,12 @@ impl Repo {
 
     /// The normalized folder name of the repo within the cache directory
     pub fn folder_name(&self) -> String {
-        self.repo_id.replace('/', "--")
+        let prefix = match self.repo_type{
+            RepoType::Model => "models",
+            RepoType::Dataset => "datasets",
+            RepoType::Space => "spaces",
+        };
+        format!("{prefix}--{}", self.repo_id).replace('/', "--")
     }
 
     /// The revision

From 926fffa0b7749e59440056b1c328829e33a28fa2 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 27 Jun 2023 17:27:54 +0000
Subject: [PATCH 2/3] Ok.
---
 candle-core/examples/llama/main.rs    |   7 +-
 candle-core/examples/llama/weights.rs | 132 +++++++++++++-----------
 candle-hub/src/api.rs                 |   1 -
 candle-hub/src/lib.rs                 |   6 +-
 4 files changed, 84 insertions(+), 62 deletions(-)

diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index f1b2c3d4..066025b1 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -15,7 +15,7 @@
 use anyhow::{Error as E, Result};
 use clap::Parser;
 use candle::{DType, Device, Tensor};
-use candle_hub::{Repo, api::Api, RepoType};
+use candle_hub::{api::Api, Repo, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -439,7 +439,10 @@ async fn main() -> Result<()> {
         .to_vec();
 
     let mut filenames = vec![];
-    for rfilename in ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]{
+    for rfilename in [
+        "model-00001-of-00002.safetensors",
+        "model-00002-of-00002.safetensors",
+    ] {
         let filename = api.get(&repo, rfilename).await?;
         filenames.push(filename);
     }
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index 2a350f65..73609d51 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -1,13 +1,16 @@
+use super::*;
+use candle::{Device, Result, Tensor};
+use half::f16;
 use memmap2::MmapOptions;
-use candle::{Device, Result, Shape, Tensor, WithDType};
+use safetensors::{
+    tensor::{Dtype, TensorView},
+    SafeTensors,
+};
 use std::fs::File;
 use std::path::PathBuf;
-use super::*;
-use safetensors::{SafeTensors, tensor::{Dtype, TensorView}};
-use half::f16;
 
-fn convert<'a>(view: TensorView<'a>, device: &Device) -> Result<Tensor>{
-    match view.dtype(){
+fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
+    match view.dtype() {
     Dtype::F16 => {
             let v = view.data();
             if (v.as_ptr() as usize) % 2 == 0 {
@@ -25,76 +28,82 @@ fn convert<'a>(view: TensorView<'a>, device: &Device) -> Result<Tensor>{
                 }
                 Tensor::from_slice(&c, view.shape(), device)
             }
-
         }
-        dt => todo!("Unhandled dtype {dt:?}")
+        dt => todo!("Unhandled dtype {dt:?}"),
     }
 }
 
-pub struct VarBuilder<'a>{
+pub struct VarBuilder<'a> {
     routing: HashMap<String, usize>,
     safetensors: Vec<SafeTensors<'a>>,
     device: Device,
 }
 
-
-impl<'a> VarBuilder<'a>{
-    pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self{
+impl<'a> VarBuilder<'a> {
+    pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self {
         let mut routing = HashMap::new();
-        for (index, sf) in safetensors.iter().enumerate(){
-            for k in sf.names(){
+        for (index, sf) in safetensors.iter().enumerate() {
+            for k in sf.names() {
                 routing.insert(k.to_string(), index);
             }
         }
 
-        Self{
+        Self {
             safetensors,
             device,
-            routing
+            routing,
         }
     }
 
-    pub fn get(&self, tensor_name: &str) -> Result<Tensor>{
+    pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
         // Unwrap or 0 just to let the proper error flow.
         let index = self.routing.get(tensor_name).unwrap_or(&0);
         let view = self.safetensors[*index].tensor(tensor_name).unwrap();
         let tensor = convert(view, &self.device)?;
         Ok(tensor)
-
     }
 }
 
-impl Linear{
-    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self>{
+impl Linear {
+    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
         let weight = vb.get(&format!("{prefix}.weight"))?;
         Ok(Self::new(weight))
     }
 
-    fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self>{
-        let weights: Vec<_> = prefixes.iter().map(|p| vb.get(&format!("{p}.weight")).unwrap()).collect();
-        println!("shapes {:?}", weights.iter().map(|w| w.shape()).collect::<Vec<_>>());
+    fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self> {
+        let weights: Vec<_> = prefixes
+            .iter()
+            .map(|p| vb.get(&format!("{p}.weight")).unwrap())
+            .collect();
         let weight = Tensor::cat(&weights, 0)?;
         Ok(Self::new(weight))
     }
 }
 
-impl RmsNorm{
-    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self>{
+impl RmsNorm {
+    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
         let scale = vb.get(&format!("{prefix}.weight"))?;
         Ok(Self::new(scale))
     }
 }
 
-impl CausalSelfAttention{
-    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self>{
-        let c_attn = Linear::load_multi(&[&format!("{prefix}.q_proj"), &format!("{prefix}.k_proj"), &format!("{prefix}.v_proj")], vb)?;
+impl CausalSelfAttention {
+    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+        let c_attn = Linear::load_multi(
+            &[
+                &format!("{prefix}.q_proj"),
+                &format!("{prefix}.k_proj"),
+                &format!("{prefix}.v_proj"),
+            ],
+            vb,
+        )?;
         let o_proj = Linear::load(&format!("{prefix}.o_proj"), vb)?;
-        Ok(Self::new(c_attn,o_proj, config.n_head, cache))
+        Ok(Self::new(c_attn, o_proj, config.n_head, cache))
     }
 }
 
-impl Mlp{
-    fn load(prefix: &str, vb: &VarBuilder, config: &Config) -> Result<Self>{
+impl Mlp {
+    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
         let c_fc1 = Linear::load(&format!("{prefix}.gate_proj"), vb)?;
         let c_fc2 = Linear::load(&format!("{prefix}.up_proj"), vb)?;
         let c_proj = Linear::load(&format!("{prefix}.down_proj"), vb)?;
@@ -102,27 +111,43 @@ impl Mlp{
     }
 }
 
-impl Block{
-    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self>{
+impl Block {
+    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(&format!("{prefix}.self_attn"), vb, cache, config)?;
-        let mlp = Mlp::load(&format!("{prefix}.mlp"), vb, config)?;
+        let mlp = Mlp::load(&format!("{prefix}.mlp"), vb)?;
         let input_layernorm = RmsNorm::load(&format!("{prefix}.input_layernorm"), vb)?;
-        let post_attention_layernorm = RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
-        Ok(Self::new(input_layernorm, attn, post_attention_layernorm, mlp))
+        let post_attention_layernorm =
+            RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
+        Ok(Self::new(
+            input_layernorm,
+            attn,
+            post_attention_layernorm,
+            mlp,
+        ))
     }
 }
 
-impl Llama{
-    pub fn load(device: &Device, filenames: &[PathBuf], cache: &Cache, config: &Config) -> Result<Self>{
-        let handles: Vec<_> = filenames.iter().map(|f| {
-            let file = File::open(f).unwrap();
-            let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
-            buffer
-        }).collect();
-        let tensors: Vec<_> = handles.iter().map(|h| {
-            let tensors = SafeTensors::deserialize(h).unwrap();
-            tensors
-        }).collect();
+impl Llama {
+    pub fn load(
+        device: &Device,
+        filenames: &[PathBuf],
+        cache: &Cache,
+        config: &Config,
+    ) -> Result<Self> {
+        let handles: Vec<_> = filenames
+            .iter()
+            .map(|f| {
+                let file = File::open(f).unwrap();
+                unsafe { MmapOptions::new().map(&file).unwrap() }
+            })
+            .collect();
+        let tensors: Vec<_> = handles
+            .iter()
+            .map(|h| {
+                let tensors = SafeTensors::deserialize(h).unwrap();
+                tensors
+            })
+            .collect();
 
         let vb = VarBuilder::new(tensors, device.clone());
 
@@ -130,15 +155,10 @@ impl Llama{
         let wte = Embedding::new(embedding);
         let lm_head = Linear::load("lm_head", &vb)?;
         let norm = RmsNorm::load("model.norm", &vb)?;
-        let blocks: Vec<_> = (0..config.n_layer).map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap()).collect();
+        let blocks: Vec<_> = (0..config.n_layer)
+            .map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap())
+            .collect();
 
-        Ok(Self::new(
-            wte,
-            blocks,
-            norm,
-            lm_head
-        ))
+        Ok(Self::new(wte, blocks, norm, lm_head))
     }
 }
-
-
diff --git a/candle-hub/src/api.rs b/candle-hub/src/api.rs
index bdf39cde..fa677f24 100644
--- a/candle-hub/src/api.rs
+++ b/candle-hub/src/api.rs
@@ -397,7 +397,6 @@ impl Api {
             let parallel_failures_semaphore = parallel_failures_semaphore.clone();
             let progress = progressbar.clone();
             handles.push(tokio::spawn(async move {
-                println!("Start {start:?} - {stop:?}");
                 let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
                 let mut i = 0;
                 if parallel_failures > 0 {
diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs
index c5965d1f..5d186039 100644
--- a/candle-hub/src/lib.rs
+++ b/candle-hub/src/lib.rs
@@ -53,9 +53,9 @@ impl Cache {
         let commit_hash = std::fs::read_to_string(commit_path).ok()?;
         let mut pointer_path = self.pointer_path(repo, &commit_hash);
         pointer_path.push(filename);
-        if pointer_path.exists(){
+        if pointer_path.exists() {
             Some(pointer_path)
-        }else{
+        } else {
             None
         }
     }
@@ -150,7 +150,7 @@ impl Repo {
 
     /// The normalized folder name of the repo within the cache directory
     pub fn folder_name(&self) -> String {
-        let prefix = match self.repo_type{
+        let prefix = match self.repo_type {
             RepoType::Model => "models",
             RepoType::Dataset => "datasets",
             RepoType::Space => "spaces",

From ece3ec6167220b66a141605deb0a4ffd0136120d Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 28 Jun 2023 12:27:03 +0000
Subject: [PATCH 3/3] Final updates -> moving to deterministic for easier comparison.
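
Sampling with `WeightedIndex` made successive runs differ, which got in the
way of comparing outputs against a reference implementation, so the
token-selection step switches to a plain argmax over the softmaxed logits.
As a rough standalone sketch of what the new selection computes (the
`argmax` helper below is illustrative only, not code from this patch;
main.rs uses an equivalent `fold`):

    fn argmax(prs: &[f32]) -> u32 {
        // Index of the largest probability; on ties the later index wins,
        // matching the fold in main.rs. total_cmp gives a total order
        // even in the presence of NaNs.
        prs.iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i as u32)
            .unwrap_or(0)
    }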
---
 candle-core/examples/llama/main.rs      | 64 ++++++++--------
 candle-core/examples/llama/var_store.rs | 98 ++++++++++++++++++++++---
 2 files changed, 120 insertions(+), 42 deletions(-)

diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 066025b1..8af465a9 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -19,9 +19,7 @@ use candle_hub::{api::Api, Repo, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
-// mod var_store;
-// use var_store::VarBuilder;
-
+mod var_store;
 mod weights;
 
 const CONTEXT_SIZE: usize = 512;
@@ -196,11 +194,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 
 impl Mlp {
     fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
-        // let n_hidden = 8 * n_embd / 3;
-        // let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
-        // let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
-        // let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
-        // let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
         Self {
             c_fc1,
             c_fc2,
@@ -244,10 +237,6 @@ impl Cache {
         let mask: Vec<_> = (0..t)
             .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
             .collect();
-        // Once lower_triangle is available, use the following:
-        //let mask = Tensor::new(1u32, &device)?
-        //    .broadcast_as(&[t, t])?
-        //    .lower_triangle()?
         let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
         masks.insert(t, mask.clone());
         Ok(mask)
@@ -265,13 +254,10 @@ struct CausalSelfAttention {
 
 impl CausalSelfAttention {
     fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
-        // let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
-        // let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
         Self {
             c_attn,
             c_proj,
             n_head,
-            // n_embd,
             cache: cache.clone(),
         }
     }
@@ -408,6 +394,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Use npy instead of safetensors
+    #[arg(long)]
+    npy: bool,
+
     /// The temperature used to generate samples.
     #[arg(long, default_value_t = 1.0)]
     temperature: f64,
@@ -419,7 +409,7 @@ struct Args {
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    use rand::prelude::*;
+    //use rand::prelude::*;
     use tokenizers::Tokenizer;
 
     let args = Args::parse();
@@ -431,6 +421,7 @@ async fn main() -> Result<()> {
     let api = Api::new()?;
     let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
     let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+    println!("Filename {tokenizer_filename:?}");
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let mut tokens = tokenizer
         .encode(START_PROMPT, true)
@@ -446,30 +437,25 @@ async fn main() -> Result<()> {
         let filename = api.get(&repo, rfilename).await?;
         filenames.push(filename);
     }
-    // let weight_path = std::path::Path::new("llama.npz");
-    // let weights = if weight_path.exists() {
-    //     println!("loading weights from {weight_path:?}");
-    //     let start_load = std::time::Instant::now();
-    //     let tensors = Tensor::read_npz(weight_path)?;
-    //     println!("loaded weights in {:?}", start_load.elapsed());
-    //     let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
-    //     Some(tensors)
-    // } else {
-    //     println!("cannot find {weight_path:?}, using zero weights");
-    //     None
-    // };
-    // let vb = VarBuilder::new::<f32>(&device, weights);
 
     println!("building the model");
     let config = Config::config_7b();
     let cache = Cache::new(&device);
-    let llama = Llama::load(&device, &filenames, &cache, &config)?;
+    let start = std::time::Instant::now();
+    let llama = if args.npy {
+        println!("building the model (NPY)");
+        Llama::load_npy(&device, &filenames, &cache, &config)?
+    } else {
+        println!("building the model (SF)");
+        Llama::load(&device, &filenames, &cache, &config)?
+    };
+    println!("Loaded in {:?}", start.elapsed());
 
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut new_tokens = vec![];
-    let mut rng = thread_rng();
+    //let mut rng = thread_rng();
     let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
@@ -477,8 +463,20 @@ async fn main() -> Result<()> {
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
-        let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
-        let next_token = distr.sample(&mut rng) as u32;
+        let next_token = logits_v
+            .iter()
+            .enumerate()
+            .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
+                if &val_max > val {
+                    (idx_max, val_max)
+                } else {
+                    (idx, *val)
+                }
+            })
+            .0 as u32;
+        // let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+
+        // let next_token = distr.sample(&mut rng) as u32;
         tokens.push(next_token);
         new_tokens.push(next_token);
         println!(
diff --git a/candle-core/examples/llama/var_store.rs b/candle-core/examples/llama/var_store.rs
index 1a400edc..8771170e 100644
--- a/candle-core/examples/llama/var_store.rs
+++ b/candle-core/examples/llama/var_store.rs
@@ -1,5 +1,7 @@
+use super::*;
 use candle::{DType, Device, Result, Shape, Tensor, WithDType};
 use std::collections::HashMap;
+use std::path::PathBuf;
 use std::sync::Arc;
 
 #[allow(dead_code)]
@@ -40,22 +42,15 @@ impl VarBuilder {
         self.vars.borrow().len()
     }
 
-    pub fn var<S: Into<Shape>>(&mut self, s: &str, shape: S) -> Result<Tensor> {
-        let shape = shape.into();
+    pub fn var(&self, s: &str) -> Result<Tensor> {
         let path = format!("{}.{s}", self.path.join("."));
-        let mut vars = self.vars.borrow_mut();
         let parameter = match self.tensors.as_ref() {
-            None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
+            None => panic!("Cannot find tensors"),
             Some(tensors) => match tensors.get(&path) {
                 Some(tensor) => tensor.to_device(&self.default_device)?,
                 None => panic!("cannot find tensor for {path}"),
             },
         };
-        vars.push(NamedVar {
-            path,
-            dtype: self.default_dtype,
-            shape,
-        });
         Ok(parameter)
     }
 
@@ -90,3 +85,88 @@ impl<S: ToString> std::ops::Div<S> for VarBuilder {
         &self / rhs
     }
 }
+
+impl Embedding {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let embeddings = vb.var("weight")?;
+        Ok(Self { embeddings })
+    }
+}
+
+impl Linear {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let weight = vb.var("weight")?.t()?;
+        Ok(Self { weight })
+    }
+}
+
+impl RmsNorm {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let scale = vb.var("scale")?;
+        Ok(Self::new(scale))
+    }
+}
+
+impl CausalSelfAttention {
+    fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+        let c_attn = Linear::load_npy(&vb / "c_attn")?;
+        let c_proj = Linear::load_npy(&vb / "c_proj")?;
+        Ok(Self::new(c_attn, c_proj, config.n_head, cache))
+    }
+}
+
+impl Mlp {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let c_fc1 = Linear::load_npy(&vb / "c_fc1")?;
+        let c_fc2 = Linear::load_npy(&vb / "c_fc2")?;
+        let c_proj = Linear::load_npy(&vb / "c_proj")?;
+        Ok(Self::new(c_fc1, c_fc2, c_proj))
+    }
+}
+
+impl Block {
+    fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+        let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?;
+        let mlp = Mlp::load_npy(&vb / "mlp")?;
+        let input_layernorm = RmsNorm::load_npy(&vb / "rms_1")?;
+        let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?;
+        Ok(Self::new(
+            input_layernorm,
+            attn,
+            post_attention_layernorm,
+            mlp,
+        ))
+    }
+}
+
+impl Llama {
+    pub fn load_npy(
+        device: &Device,
+        _filenames: &[PathBuf],
+        cache: &Cache,
+        config: &Config,
+    ) -> Result<Self> {
+        let weight_path = std::path::Path::new("/data/llama.npz");
+        let weights = if weight_path.exists() {
+            println!("loading weights from {weight_path:?}");
+            let start_load = std::time::Instant::now();
+            let tensors = Tensor::read_npz(weight_path)?;
+            println!("loaded weights in {:?}", start_load.elapsed());
+            let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
+            Some(tensors)
+        } else {
+            println!("cannot find {weight_path:?}, using zero weights");
+            None
+        };
+        let vb = VarBuilder::new::<f32>(device, weights);
+
+        let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
+        let lm_head = Linear::load_npy(&vb / "lm_head")?;
+        let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?;
+        let blocks: Vec<_> = (0..config.n_layer)
+            .map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap())
+            .collect();
+
+        Ok(Self::new(wte, blocks, norm, lm_head))
+    }
+}