diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 32b8a122..75cea7ff 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -20,11 +20,10 @@ use rand::{distributions::Distribution, SeedableRng}; use candle::{DType, Device, Tensor, D}; use candle_hub::{api::sync::Api, Repo, RepoType}; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use candle_nn::VarBuilder; -mod var_store; -mod weights; +mod model; +use model::{Config, Llama}; const MAX_SEQ_LEN: usize = 4096; #[cfg(feature = "mkl")] @@ -83,337 +82,6 @@ Whate'er it bodes, henceforward will I bear Upon my target three fair-shining suns. "; -#[allow(dead_code)] -struct Config { - block_size: usize, - vocab_size: usize, - n_layer: usize, - n_head: usize, - n_embd: usize, -} - -#[allow(dead_code)] -impl Config { - fn config_7b() -> Self { - Self { - block_size: 4096, - vocab_size: 32000, - n_layer: 32, - n_head: 32, - n_embd: 4096, - } - } - - fn config_13b() -> Self { - Self { - block_size: 4096, - vocab_size: 32000, - n_layer: 40, - n_head: 40, - n_embd: 5120, - } - } - - fn config_30b() -> Self { - Self { - block_size: 4096, - vocab_size: 32000, - n_layer: 60, - n_head: 52, - n_embd: 6656, - } - } - - fn config_65b() -> Self { - Self { - block_size: 4096, - vocab_size: 32000, - n_layer: 80, - n_head: 64, - n_embd: 8192, - } - } -} - -struct Embedding { - embeddings: Tensor, -} - -impl Embedding { - fn new(embeddings: Tensor) -> Self { - Self { embeddings } - } - - fn forward(&self, indexes: &Tensor) -> Result { - let embeddings = self.embeddings.to_dtype(DTYPE)?; - Ok(Tensor::embedding(indexes, &embeddings)?) - } -} - -struct Linear { - weight: Tensor, -} - -impl Linear { - fn new(weight: Tensor) -> Self { - Self { weight } - } - - fn forward(&self, x: &Tensor) -> Result { - let weight = self.weight.to_dtype(DTYPE)?; - let x = x.matmul(&weight.t()?)?; - Ok(x) - } -} - -struct RmsNorm { - scale: Tensor, -} - -impl RmsNorm { - fn new(scale: Tensor) -> Self { - Self { scale } - } - - fn forward(&self, x: &Tensor) -> Result { - // This is a no-op if x's dtype is already f32. - let x = x.to_dtype(DType::F32)?; - let (seq_len, hidden_size) = x.shape().r2()?; - let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?; - let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?; - let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?; - let size = self.scale.shape().r1()?; - let scale = self - .scale - .to_dtype(DType::F32)? - .broadcast_as((seq_len, size))?; - let x = (scale * x_normed)?; - let x = x.to_dtype(DTYPE)?; - Ok(x) - } -} - -struct Mlp { - c_fc1: Linear, - c_fc2: Linear, - c_proj: Linear, -} - -fn silu(xs: &Tensor) -> Result { - Ok((xs / (xs.neg()?.exp()? + 1.0)?)?) -} - -impl Mlp { - fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self { - Self { - c_fc1, - c_fc2, - c_proj, - } - } - - fn forward(&self, x: &Tensor) -> Result { - let x = (silu(&self.c_fc1.forward(x)?)? 
* self.c_fc2.forward(x)?)?; - self.c_proj.forward(&x) - } -} - -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { - let shape = mask.shape(); - let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?; - let m = mask.where_cond(&on_true, on_false)?; - Ok(m) -} - -#[derive(Clone)] -struct Cache { - masks: Arc>>, - use_kv_cache: bool, - #[allow(clippy::type_complexity)] - kvs: Arc>>>, - device: Device, -} - -impl Cache { - fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self { - Self { - masks: Arc::new(Mutex::new(HashMap::new())), - use_kv_cache, - kvs: Arc::new(Mutex::new(vec![None; config.n_layer])), - device: device.clone(), - } - } - - fn mask(&self, t: usize) -> Result { - let mut masks = self.masks.lock().unwrap(); - if let Some(mask) = masks.get(&t) { - Ok(mask.clone()) - } else { - // TODO: If we support bool or u8 tensors, this would be better. - let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u32::from(j > i))) - .collect(); - let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; - masks.insert(t, mask.clone()); - Ok(mask) - } - } -} - -struct CausalSelfAttention { - c_attn: Linear, - c_proj: Linear, - n_head: usize, - cache: Cache, -} - -impl CausalSelfAttention { - fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self { - Self { - c_attn, - c_proj, - n_head, - cache: cache.clone(), - } - } - - fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result { - let mut dims = x.dims().to_vec(); - let fcis_dims = freqs_cis.dims(); - let freqs_cis = if dims[1] < fcis_dims[1] { - freqs_cis.narrow(1, 0, dims[1])? - } else { - freqs_cis.clone() - }; - let v = dims.pop().unwrap(); - dims.push(v / 2); - dims.push(2); - let x = x.reshape(dims)?; - let re_x = x.narrow(D::Minus1, 0, 1)?; - let im_x = x.narrow(D::Minus1, 1, 1)?; - let re_f = freqs_cis - .narrow(D::Minus1, 0, 1)? - .broadcast_as(re_x.shape())?; - let im_f = freqs_cis - .narrow(D::Minus1, 1, 1)? - .broadcast_as(im_x.shape())?; - let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?; - let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?; - let rope = Tensor::cat(&[&re, &im], D::Minus1)?; - let rope = rope.flatten_from(D::Minus2)?; - Ok(rope) - } - - fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result { - let (t, c) = x.shape().r2()?; - let qkv = self.c_attn.forward(x)?; - let qkv = qkv.to_dtype(DType::F32)?; - let n_embd = c; - let q = qkv.narrow(1, 0, n_embd)?; - let k = qkv.narrow(1, n_embd, n_embd)?; - let v = qkv.narrow(1, 2 * n_embd, n_embd)?; - let target_dim = [t, self.n_head, c / self.n_head]; - let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?; - let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?; - let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?; - let q = self.apply_rotary_emb(&q, freqs_cis)?; - let mut k = self.apply_rotary_emb(&k, freqs_cis)?; - - if self.cache.use_kv_cache { - let mut cache = self.cache.kvs.lock().unwrap(); - if let Some((cache_k, cache_v)) = &cache[block_idx] { - k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?; - v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?; - let k_seq_len = k.dims()[1]; - if k_seq_len > MAX_SEQ_LEN { - k = k - .narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? - .contiguous()? - } - let v_seq_len = v.dims()[1]; - if v_seq_len > 2 * MAX_SEQ_LEN { - v = v - .narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? - .contiguous()? 
- } - } - cache[block_idx] = Some((k.clone(), v.clone())) - } - - let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?; - let mask = self.cache.mask(t)?.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let att = att.softmax(D::Minus1)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; - let y = y.transpose(0, 1)?.reshape(&[t, c])?; - let y = y.to_dtype(DTYPE)?; - let y = self.c_proj.forward(&y)?; - Ok(y) - } -} - -struct Block { - rms_1: RmsNorm, - attn: CausalSelfAttention, - rms_2: RmsNorm, - mlp: Mlp, -} - -impl Block { - fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { - Self { - rms_1, - attn, - rms_2, - mlp, - } - } - - fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result { - let x = (self - .attn - .forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)? - + x)?; - let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?; - Ok(x) - } -} - -struct Llama { - wte: Embedding, - blocks: Vec, - ln_f: RmsNorm, - lm_head: Linear, -} - -impl Llama { - fn new(wte: Embedding, blocks: Vec, ln_f: RmsNorm, lm_head: Linear) -> Self { - Self { - wte, - blocks, - ln_f, - lm_head, - } - } - - fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result { - // TODO: Support for mini-batches? (i.e. r2) - let t = x.shape().r1()?; - let mut x = self.wte.forward(x)?; - for (block_idx, block) in self.blocks.iter().enumerate() { - x = block.forward(&x, freqs_cis, block_idx)?; - } - let x = self.ln_f.forward(&x)?; - let x = x.narrow(0, t - 1, 1)?; - let logits = self.lm_head.forward(&x)?; - let logits = logits.to_dtype(DType::F32)?; - let (b, vocab_size) = logits.shape().r2()?; - assert_eq!(b, 1); - Ok(logits.reshape(vocab_size)?) - } -} - fn precompute_freqs_cis(config: &Config, device: &Device) -> Result { let n_elem = config.n_embd / config.n_head; let theta: Vec<_> = (0..n_elem) @@ -474,19 +142,15 @@ fn main() -> Result<()> { Device::new_cuda(0)? }; let config = Config::config_7b(); - let cache = Cache::new(!args.no_kv_cache, &config, &device); - let start = std::time::Instant::now(); + let cache = model::Cache::new(!args.no_kv_cache, &config, &device); let (llama, tokenizer_filename) = match args.npy { - Some(npy) => { - println!("building the model (NPY)"); - let weights = Llama::load_npy(&device, &npy, &cache, &config)?; - let token_path = std::path::Path::new("llama-tokenizer.json").to_path_buf(); - (weights, token_path) + Some(_) => { + todo!("fix numpy handling if we continue supporting it") } None => { let api = Api::new()?; let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); - println!("building the model"); + println!("loading the model weights"); let tokenizer_filename = api.get(&repo, "tokenizer.json")?; let mut filenames = vec![]; for rfilename in [ @@ -497,14 +161,20 @@ fn main() -> Result<()> { filenames.push(filename); } - println!("building the model (SF)"); - ( - Llama::load(&device, &filenames, &cache, &config)?, - tokenizer_filename, - ) + println!("building the model"); + let handles = filenames + .iter() + .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? 
})) + .collect::>>()?; + let tensors: Vec<_> = handles + .iter() + .map(|h| Ok(h.deserialize()?)) + .collect::>>()?; + + let vb = VarBuilder::from_safetensors(tensors, DTYPE, &device); + (Llama::load(vb, &cache, &config)?, tokenizer_filename) } }; - println!("Loaded in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str()); let mut tokens = tokenizer diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs new file mode 100644 index 00000000..8ff15564 --- /dev/null +++ b/candle-examples/examples/llama/model.rs @@ -0,0 +1,364 @@ +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn::{Linear, VarBuilder}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use super::MAX_SEQ_LEN; + +pub struct Config { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub n_layer: usize, + pub n_head: usize, + pub n_embd: usize, +} + +impl Config { + pub fn config_7b() -> Self { + Self { + hidden_size: 4096, + intermediate_size: 11008, + vocab_size: 32000, + n_layer: 32, + n_head: 32, + n_embd: 4096, + } + } +} + +#[derive(Clone)] +pub struct Cache { + masks: Arc>>, + pub use_kv_cache: bool, + #[allow(clippy::type_complexity)] + kvs: Arc>>>, + device: Device, +} + +impl Cache { + pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self { + Self { + masks: Arc::new(Mutex::new(HashMap::new())), + use_kv_cache, + kvs: Arc::new(Mutex::new(vec![None; config.n_layer])), + device: device.clone(), + } + } + + fn mask(&self, t: usize) -> Result { + let mut masks = self.masks.lock().unwrap(); + if let Some(mask) = masks.get(&t) { + Ok(mask.clone()) + } else { + // TODO: If we support bool or u8 tensors, this would be better. + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u32::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; + masks.insert(t, mask.clone()); + Ok(mask) + } + } +} + +fn silu(xs: &Tensor) -> Result { + xs / (xs.neg()?.exp()? + 1.0)? +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; + Ok(Linear::new(weight, None)) +} + +struct Embedding { + embeddings: Tensor, +} + +impl Embedding { + fn new(embeddings: Tensor) -> Self { + Self { embeddings } + } + + fn load(cfg: &Config, vb: VarBuilder) -> Result { + let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; + Ok(Self::new(embeddings)) + } + + fn forward(&self, indexes: &Tensor) -> Result { + Tensor::embedding(indexes, &self.embeddings) + } +} + +struct RmsNorm { + scale: Tensor, +} + +impl RmsNorm { + fn load(size: usize, vb: VarBuilder) -> Result { + let scale = vb.get(size, "weight")?; + Ok(Self::new(scale)) + } + + fn new(scale: Tensor) -> Self { + Self { scale } + } + + fn forward(&self, x: &Tensor) -> Result { + let in_dtype = x.dtype(); + // This is a no-op if x's dtype is already f32. + let x = x.to_dtype(DType::F32)?; + let (seq_len, hidden_size) = x.shape().r2()?; + let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?; + let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?; + let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?; + let size = self.scale.shape().r1()?; + let scale = self + .scale + .to_dtype(DType::F32)? 
+ .broadcast_as((seq_len, size))?; + let x = (scale * x_normed)?; + let x = x.to_dtype(in_dtype)?; + Ok(x) + } +} + +struct CausalSelfAttention { + c_attn: Linear, + c_proj: Linear, + n_head: usize, + cache: Cache, +} + +impl CausalSelfAttention { + fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self { + Self { + c_attn, + c_proj, + n_head, + cache: cache.clone(), + } + } + + fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result { + let mut dims = x.dims().to_vec(); + let fcis_dims = freqs_cis.dims(); + let freqs_cis = if dims[1] < fcis_dims[1] { + freqs_cis.narrow(1, 0, dims[1])? + } else { + freqs_cis.clone() + }; + let v = dims.pop().unwrap(); + dims.push(v / 2); + dims.push(2); + let x = x.reshape(dims)?; + let re_x = x.narrow(D::Minus1, 0, 1)?; + let im_x = x.narrow(D::Minus1, 1, 1)?; + let re_f = freqs_cis + .narrow(D::Minus1, 0, 1)? + .broadcast_as(re_x.shape())?; + let im_f = freqs_cis + .narrow(D::Minus1, 1, 1)? + .broadcast_as(im_x.shape())?; + let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?; + let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?; + let rope = Tensor::cat(&[&re, &im], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + Ok(rope) + } + + fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result { + let x_dtype = x.dtype(); + let (t, c) = x.shape().r2()?; + let qkv = self.c_attn.forward(x)?; + let qkv = qkv.to_dtype(DType::F32)?; + let n_embd = c; + let q = qkv.narrow(1, 0, n_embd)?; + let k = qkv.narrow(1, n_embd, n_embd)?; + let v = qkv.narrow(1, 2 * n_embd, n_embd)?; + let target_dim = [t, self.n_head, c / self.n_head]; + let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?; + let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?; + let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?; + let q = self.apply_rotary_emb(&q, freqs_cis)?; + let mut k = self.apply_rotary_emb(&k, freqs_cis)?; + + if self.cache.use_kv_cache { + let mut cache = self.cache.kvs.lock().unwrap(); + if let Some((cache_k, cache_v)) = &cache[block_idx] { + k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?; + v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?; + let k_seq_len = k.dims()[1]; + if k_seq_len > MAX_SEQ_LEN { + k = k + .narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + let v_seq_len = v.dims()[1]; + if v_seq_len > 2 * MAX_SEQ_LEN { + v = v + .narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + } + cache[block_idx] = Some((k.clone(), v.clone())) + } + + let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?; + let mask = self.cache.mask(t)?.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = att.softmax(D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. 
+ let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(0, 1)?.reshape(&[t, c])?; + let y = y.to_dtype(x_dtype)?; + let y = self.c_proj.forward(&y)?; + Ok(y) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result { + let size_in = cfg.hidden_size; + let size = (cfg.hidden_size / cfg.n_head) * cfg.n_head; + let q_proj = vb.get((size_in, size), "q_proj.weight")?; + let k_proj = vb.get((size_in, size), "k_proj.weight")?; + let v_proj = vb.get((size_in, size), "v_proj.weight")?; + // Invert the transformation from: + // https://github.com/huggingface/transformers/blob/2642d8d04b14c18199ebe7b35f976da02df61752/src/transformers/models/llama/convert_llama_weights_to_hf.py#L101 + let n_head = cfg.n_head; + let q_proj = q_proj + .reshape((n_head, 2, size / n_head / 2, size_in))? + .transpose(1, 2)? + .reshape((size_in, size))?; + let k_proj = k_proj + .reshape((n_head, 2, size / n_head / 2, size_in))? + .transpose(1, 2)? + .reshape((size_in, size))?; + let attn_weight = Tensor::cat(&[q_proj, k_proj, v_proj], 0)?; + let c_attn = Linear::new(attn_weight, None); + let o_proj = linear(size, size_in, vb.pp("o_proj"))?; + Ok(Self::new(c_attn, o_proj, cfg.n_head, cache)) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +struct Mlp { + c_fc1: Linear, + c_fc2: Linear, + c_proj: Linear, +} + +impl Mlp { + fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self { + Self { + c_fc1, + c_fc2, + c_proj, + } + } + + fn forward(&self, x: &Tensor) -> Result { + let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let h_size = cfg.hidden_size; + let i_size = cfg.intermediate_size; + let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?; + let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?; + let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?; + Ok(Self::new(c_fc1, c_fc2, c_proj)) + } +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, +} + +impl Block { + fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { + Self { + rms_1, + attn, + rms_2, + mlp, + } + } + + fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result { + let x = (self + .attn + .forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)? + + x)?; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?; + Ok(x) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result { + let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?; + let mlp = Mlp::load(vb.pp("mlp"), cfg)?; + let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?; + let post_attention_layernorm = + RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?; + Ok(Self::new( + input_layernorm, + attn, + post_attention_layernorm, + mlp, + )) + } +} + +pub struct Llama { + wte: Embedding, + blocks: Vec, + ln_f: RmsNorm, + lm_head: Linear, +} + +impl Llama { + fn new(wte: Embedding, blocks: Vec, ln_f: RmsNorm, lm_head: Linear) -> Self { + Self { + wte, + blocks, + ln_f, + lm_head, + } + } + + pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result { + // TODO: Support for mini-batches? (i.e. 
r2) + let t = x.shape().r1()?; + let mut x = self.wte.forward(x)?; + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, freqs_cis, block_idx)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.narrow(0, t - 1, 1)?; + let logits = self.lm_head.forward(&x)?; + let logits = logits.to_dtype(DType::F32)?; + let (b, vocab_size) = logits.shape().r2()?; + assert_eq!(b, 1); + logits.reshape(vocab_size) + } + + pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result { + let wte = Embedding::load(cfg, vb.pp("model.embed_tokens"))?; + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?; + let blocks: Vec<_> = (0..cfg.n_layer) + .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap()) + .collect(); + + Ok(Self::new(wte, blocks, norm, lm_head)) + } +} diff --git a/candle-examples/examples/llama/var_store.rs b/candle-examples/examples/llama/var_store.rs deleted file mode 100644 index bd1114a0..00000000 --- a/candle-examples/examples/llama/var_store.rs +++ /dev/null @@ -1,137 +0,0 @@ -use super::*; -use candle::{Device, Result, Tensor}; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; - -#[derive(Clone)] -pub struct VarBuilder { - path: Vec, - default_device: Device, - tensors: Arc>>, -} - -impl VarBuilder { - pub fn new(device: &Device, tensors: HashMap) -> Self { - Self { - path: vec![], - tensors: Arc::new(Mutex::new(tensors)), - default_device: device.clone(), - } - } - - pub fn get_and_remove(&self, s: &str) -> Result { - let path = format!("{}.{s}", self.path.join(".")); - let mut tensors = self.tensors.as_ref().lock().unwrap(); - let parameter = match tensors.remove(&path) { - Some(tensor) => tensor.to_device(&self.default_device)?, - None => panic!("cannot find tensor for {path}"), - }; - Ok(parameter) - } -} - -impl std::ops::Div for &VarBuilder { - type Output = VarBuilder; - - fn div(self, rhs: S) -> VarBuilder { - let mut path = self.path.clone(); - path.push(rhs.to_string()); - VarBuilder { - path, - default_device: self.default_device.clone(), - tensors: self.tensors.clone(), - } - } -} - -impl std::ops::Div for VarBuilder { - type Output = VarBuilder; - - fn div(self, rhs: S) -> VarBuilder { - &self / rhs - } -} - -impl Embedding { - fn load_npy(vb: VarBuilder) -> Result { - let embeddings = vb.get_and_remove("weight")?.to_dtype(DTYPE)?; - Ok(Self { embeddings }) - } -} - -impl Linear { - fn load_npy(vb: VarBuilder) -> Result { - let weight = vb.get_and_remove("weight")?.to_dtype(DTYPE)?.t()?; - Ok(Self { weight }) - } -} - -impl RmsNorm { - fn load_npy(vb: VarBuilder) -> Result { - let scale = vb.get_and_remove("scale")?.to_dtype(DTYPE)?; - Ok(Self::new(scale)) - } -} - -impl CausalSelfAttention { - fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result { - let c_attn = Linear::load_npy(&vb / "c_attn")?; - let c_proj = Linear::load_npy(&vb / "c_proj")?; - Ok(Self::new(c_attn, c_proj, config.n_head, cache)) - } -} - -impl Mlp { - fn load_npy(vb: VarBuilder) -> Result { - let c_fc1 = Linear::load_npy(&vb / "c_fc1")?; - let c_fc2 = Linear::load_npy(&vb / "c_fc2")?; - let c_proj = Linear::load_npy(&vb / "c_proj")?; - Ok(Self::new(c_fc1, c_fc2, c_proj)) - } -} - -impl Block { - fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result { - let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?; - let mlp = Mlp::load_npy(&vb / "mlp")?; - let input_layernorm = RmsNorm::load_npy(&vb / 
"rms_1")?; - let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?; - Ok(Self::new( - input_layernorm, - attn, - post_attention_layernorm, - mlp, - )) - } -} - -impl Llama { - pub fn load_npy( - device: &Device, - filename: &str, - cache: &Cache, - config: &Config, - ) -> anyhow::Result { - let weight_path = std::path::Path::new(filename); - let weights = if weight_path.exists() { - println!("loading weights from {weight_path:?}"); - let start_load = std::time::Instant::now(); - let tensors = Tensor::read_npz(weight_path)?; - println!("loaded weights in {:?}", start_load.elapsed()); - let tensors: std::collections::HashMap = tensors.into_iter().collect(); - tensors - } else { - anyhow::bail!("cannot find {weight_path:?}") - }; - let vb = VarBuilder::new(device, weights); - - let wte = Embedding::load_npy(&vb / "transformer" / "wte")?; - let lm_head = Linear::load_npy(&vb / "lm_head")?; - let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?; - let blocks: Vec<_> = (0..config.n_layer) - .map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap()) - .collect(); - - Ok(Self::new(wte, blocks, norm, lm_head)) - } -} diff --git a/candle-examples/examples/llama/weights.rs b/candle-examples/examples/llama/weights.rs deleted file mode 100644 index c3364cef..00000000 --- a/candle-examples/examples/llama/weights.rs +++ /dev/null @@ -1,127 +0,0 @@ -use super::*; -use candle::{safetensors::SafeTensors, Device, Result, Tensor}; -use std::path::PathBuf; - -pub struct VarBuilder<'a> { - routing: HashMap, - safetensors: Vec>, - device: Device, -} - -impl<'a> VarBuilder<'a> { - pub fn new(safetensors: Vec>, device: Device) -> Self { - let mut routing = HashMap::new(); - for (index, sf) in safetensors.iter().enumerate() { - for k in sf.names() { - routing.insert(k.to_string(), index); - } - } - - Self { - safetensors, - device, - routing, - } - } - - pub fn get(&self, tensor_name: &str) -> Result { - // Unwrap or 0 just to let the proper error flow. - let index = self.routing.get(tensor_name).unwrap_or(&0); - self.safetensors[*index] - .tensor(tensor_name, &self.device)? 
- .to_dtype(DTYPE) - } -} - -impl Linear { - fn load(prefix: &str, vb: &VarBuilder) -> Result { - let weight = vb.get(&format!("{prefix}.weight"))?; - Ok(Self::new(weight)) - } - - fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result { - let weights: Vec<_> = prefixes - .iter() - .map(|p| vb.get(&format!("{p}.weight")).unwrap()) - .collect(); - let weight = Tensor::cat(&weights, 0)?; - Ok(Self::new(weight)) - } -} - -impl RmsNorm { - fn load(prefix: &str, vb: &VarBuilder) -> Result { - let scale = vb.get(&format!("{prefix}.weight"))?; - Ok(Self::new(scale)) - } -} - -impl CausalSelfAttention { - fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result { - let c_attn = Linear::load_multi( - &[ - &format!("{prefix}.q_proj"), - &format!("{prefix}.k_proj"), - &format!("{prefix}.v_proj"), - ], - vb, - )?; - let o_proj = Linear::load(&format!("{prefix}.o_proj"), vb)?; - Ok(Self::new(c_attn, o_proj, config.n_head, cache)) - } -} - -impl Mlp { - fn load(prefix: &str, vb: &VarBuilder) -> Result { - let c_fc1 = Linear::load(&format!("{prefix}.gate_proj"), vb)?; - let c_fc2 = Linear::load(&format!("{prefix}.up_proj"), vb)?; - let c_proj = Linear::load(&format!("{prefix}.down_proj"), vb)?; - Ok(Self::new(c_fc1, c_fc2, c_proj)) - } -} - -impl Block { - fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result { - let attn = CausalSelfAttention::load(&format!("{prefix}.self_attn"), vb, cache, config)?; - let mlp = Mlp::load(&format!("{prefix}.mlp"), vb)?; - let input_layernorm = RmsNorm::load(&format!("{prefix}.input_layernorm"), vb)?; - let post_attention_layernorm = - RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?; - Ok(Self::new( - input_layernorm, - attn, - post_attention_layernorm, - mlp, - )) - } -} - -impl Llama { - pub fn load( - device: &Device, - filenames: &[PathBuf], - cache: &Cache, - config: &Config, - ) -> Result { - let handles: Vec<_> = filenames - .iter() - .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) - .collect::>>()?; - let tensors: Vec<_> = handles - .iter() - .map(|h| h.deserialize()) - .collect::>>()?; - - let vb = VarBuilder::new(tensors, device.clone()); - - let embedding = vb.get("model.embed_tokens.weight")?; - let wte = Embedding::new(embedding); - let lm_head = Linear::load("lm_head", &vb)?; - let norm = RmsNorm::load("model.norm", &vb)?; - let blocks: Vec<_> = (0..config.n_layer) - .map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap()) - .collect(); - - Ok(Self::new(wte, blocks, norm, lm_head)) - } -}
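
For reference, below is a condensed, self-contained sketch of the safetensors loading path this patch introduces in main.rs, assuming the candle / candle_nn APIs exactly as they appear in the diff above. The function name `load_llama`, the `weight_files` argument, and the use of `DType::F16` in place of the example's `DTYPE` constant are illustrative placeholders, not part of the patch.

use candle::{DType, Device};
use candle_nn::VarBuilder;

// Build the refactored model from memory-mapped safetensors shards.
fn load_llama(
    weight_files: &[std::path::PathBuf],
    device: &Device,
) -> anyhow::Result<model::Llama> {
    let config = model::Config::config_7b();
    let cache = model::Cache::new(/* use_kv_cache */ true, &config, device);
    // Memory-map each shard; the deserialized tensors below borrow from these maps.
    let handles = weight_files
        .iter()
        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
        .collect::<anyhow::Result<Vec<_>>>()?;
    let tensors = handles
        .iter()
        .map(|h| Ok(h.deserialize()?))
        .collect::<anyhow::Result<Vec<_>>>()?;
    // VarBuilder resolves names such as "model.layers.0.self_attn.q_proj.weight"
    // against whichever shard holds them when Llama::load requests them.
    let vb = VarBuilder::from_safetensors(tensors, DType::F16, device);
    Ok(model::Llama::load(vb, &cache, &config)?)
}

This replaces the per-file VarBuilder implementations removed in var_store.rs and weights.rs: weight routing and dtype/device placement now live in candle_nn::VarBuilder, and the model code in model.rs only declares the shapes and names it expects.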