diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 4e7015dd..77900d27 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -1,5 +1,6 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::linear_no_bias as linear;
+use candle_nn::{embedding, Embedding, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -43,8 +44,25 @@ pub struct Cache {
 
 impl Cache {
     pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?;
-        let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?;
+        let n_elem = cfg.dim / cfg.n_heads;
+        let theta: Vec<_> = (0..n_elem)
+            .step_by(2)
+            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
+            .collect();
+        let theta = Tensor::new(theta.as_slice(), vb.device())?;
+        let idx_theta = Tensor::arange(0, cfg.seq_len as u32, vb.device())?
+            .to_dtype(DType::F32)?
+            .reshape((cfg.seq_len, 1))?
+            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+        let precomputed_cos = idx_theta.cos()?;
+        let precomputed_sin = idx_theta.sin()?;
+
+        let freq_cis_real = vb
+            .get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")
+            .unwrap_or(precomputed_cos);
+        let freq_cis_imag = vb
+            .get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")
+            .unwrap_or(precomputed_sin);
         let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
         let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
         Ok(Self {
@@ -76,16 +94,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
 }
 
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
-}
-
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
-    Ok(Embedding::new(embeddings, cfg.dim))
-}
-
 struct RmsNorm {
     scale: Tensor,
     eps: f64,
 }
@@ -93,7 +101,7 @@ struct RmsNorm {
 
 impl RmsNorm {
     fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get(size, "weight")?;
+        let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
         Ok(Self { scale, eps })
     }
 
@@ -315,7 +323,7 @@ impl Llama {
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
-        let wte = embedding(&cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
         let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
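For reference, the fallback that Cache::new now computes when the checkpoint has no freq_cis_real/freq_cis_imag tensors is the standard rotary-embedding table: theta_i = 1 / 10000^(i / n_elem) for even i, multiplied by the position index, then cos and sin of the products, giving two (seq_len, n_elem / 2) tables. A minimal plain-Rust sketch of the same values, without candle tensors (seq_len and n_elem are illustrative parameters):

    // Standalone sketch of the rotary tables precomputed above; no candle involved.
    fn rope_tables(seq_len: usize, n_elem: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        // One frequency per pair of channels: theta_i = 1 / 10000^(i / n_elem) for even i.
        let theta: Vec<f32> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
            .collect();
        // Outer product of positions and frequencies, then cos/sin,
        // i.e. two tables of shape (seq_len, n_elem / 2).
        let cos = (0..seq_len)
            .map(|p| theta.iter().map(|t| (p as f32 * t).cos()).collect())
            .collect();
        let sin = (0..seq_len)
            .map(|p| theta.iter().map(|t| (p as f32 * t).sin()).collect())
            .collect();
        (cos, sin)
    }
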
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index 196ba9a8..92aa90e6 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -142,15 +142,15 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
         dataset.train_tokens.len(),
         dataset.valid_tokens.len()
     );
-    let vb = candle_nn::VarBuilder::zeros(DType::F32, &device);
+    let varmap = candle_nn::VarMap::new();
+    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
     let config = Config::tiny();
     let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
     let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
 
     let cache = Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
-    let all_vars = vec![]; // TODO: Propagate the variables from the VarBuilder to here.
-    let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
+    let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
     for (batch_index, batch) in batch_iter.enumerate() {
         let (inp, tgt) = batch?;
         let logits = model.forward(&inp, 0)?;
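The training side is now wired as follows: a VarMap owns the trainable variables, a VarBuilder created with from_varmap registers them while the model is built, and the optimizer receives them through all_vars(). Below is a condensed, hypothetical sketch of that wiring; the tiny linear layer, the sizes and the log_softmax/nll loss are stand-ins borrowed from the mnist example, not this file's real code:

    use candle::{DType, Device, Result, Tensor, D};
    use candle_nn::{loss, ops, VarBuilder, VarMap};

    // Hypothetical, condensed wiring: the VarMap owns the variables, the VarBuilder
    // registers them during model construction, and the optimizer takes all_vars().
    // The linear layer and sizes below are stand-ins, not the example's real model.
    fn train_step_sketch() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let model = candle_nn::linear(16, 4, vb.pp("lm_head"))?;
        let sgd = candle_nn::SGD::new(varmap.all_vars(), 0.1);

        let xs = Tensor::zeros((4, 16), DType::F32, &device)?;
        let ys = Tensor::new(&[0u32, 1, 2, 3], &device)?;
        let logits = model.forward(&xs)?;
        let loss = loss::nll(&ops::log_softmax(&logits, D::Minus1)?, &ys)?;
        sgd.backward_step(&loss)?; // assumes SGD::backward_step, as used by the examples
        Ok(())
    }
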
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index 5bc2e99b..e251f6e9 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -4,128 +4,20 @@ extern crate intel_mkl_src;
 
 use clap::{Parser, ValueEnum};
 
-use candle::{DType, Device, Result, Shape, Tensor, Var, D};
-use candle_nn::{loss, ops, Init, Linear};
-use std::sync::{Arc, Mutex};
+use candle::{DType, Result, Tensor, D};
+use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
 
-struct TensorData {
-    tensors: std::collections::HashMap<String, Var>,
-    pub dtype: DType,
-    pub device: Device,
-}
-
-// A variant of candle_nn::VarBuilder for initializing variables before training.
-#[derive(Clone)]
-struct VarStore {
-    data: Arc<Mutex<TensorData>>,
-    path: Vec<String>,
-}
-
-impl VarStore {
-    fn new(dtype: DType, device: Device) -> Self {
-        let data = TensorData {
-            tensors: std::collections::HashMap::new(),
-            dtype,
-            device,
-        };
-        Self {
-            data: Arc::new(Mutex::new(data)),
-            path: vec![],
-        }
-    }
-
-    fn pp(&self, s: &str) -> Self {
-        let mut path = self.path.clone();
-        path.push(s.to_string());
-        Self {
-            data: self.data.clone(),
-            path,
-        }
-    }
-
-    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
-        let shape = shape.into();
-        let path = if self.path.is_empty() {
-            tensor_name.to_string()
-        } else {
-            [&self.path.join("."), tensor_name].join(".")
-        };
-        let mut tensor_data = self.data.lock().unwrap();
-        if let Some(tensor) = tensor_data.tensors.get(&path) {
-            let tensor_shape = tensor.shape();
-            if &shape != tensor_shape {
-                candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
-            }
-            return Ok(tensor.as_tensor().clone());
-        }
-        let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
-        let tensor = var.as_tensor().clone();
-        tensor_data.tensors.insert(path, var);
-        Ok(tensor)
-    }
-
-    fn all_vars(&self) -> Vec<Var> {
-        let tensor_data = self.data.lock().unwrap();
-        #[allow(clippy::map_clone)]
-        tensor_data
-            .tensors
-            .values()
-            .map(|c| c.clone())
-            .collect::<Vec<_>>()
-    }
-
-    fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
-        let tensor_data = self.data.lock().unwrap();
-        let data = tensor_data.tensors.iter().map(|(k, v)| (k, v.as_tensor()));
-        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
-        Ok(())
-    }
-
-    fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
-        use candle::safetensors::Load;
-
-        let path = path.as_ref();
-        let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
-        let data = data.deserialize()?;
-        let mut tensor_data = self.data.lock().unwrap();
-        for (name, var) in tensor_data.tensors.iter_mut() {
-            match data.tensor(name) {
-                Ok(data) => {
-                    let data: Tensor = data.load(var.device())?;
-                    if let Err(err) = var.set(&data) {
-                        candle::bail!("error setting {name} using data from {path:?}: {err}",)
-                    }
-                }
-                Err(_) => candle::bail!("cannot find tensor for {name}"),
-            }
-        }
-        Ok(())
-    }
-}
-
-fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
-    let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
-    let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
-    Ok(Linear::new(ws, Some(bs)))
-}
-
-fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
-    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
-    let bound = 1. / (in_dim as f64).sqrt();
-    let init_bs = Init::Uniform {
-        lo: -bound,
-        up: bound,
-    };
-    let bs = vs.get(out_dim, "bias", init_bs)?;
+fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
+    let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
+    let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
     Ok(Linear::new(ws, Some(bs)))
 }
 
 trait Model: Sized {
-    fn new(vs: VarStore) -> Result<Self>;
+    fn new(vs: VarBuilder) -> Result<Self>;
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
 }
 
@@ -134,7 +26,7 @@ struct LinearModel {
 }
 
 impl Model for LinearModel {
-    fn new(vs: VarStore) -> Result<Self> {
+    fn new(vs: VarBuilder) -> Result<Self> {
         let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
         Ok(Self { linear })
     }
@@ -150,9 +42,9 @@ struct Mlp {
 }
 
 impl Model for Mlp {
-    fn new(vs: VarStore) -> Result<Self> {
-        let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
-        let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
+    fn new(vs: VarBuilder) -> Result<Self> {
+        let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
+        let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?;
         Ok(Self { ln1, ln2 })
     }
 
@@ -180,17 +72,16 @@ fn training_loop<M: Model>(
     let train_images = m.train_images.to_device(&dev)?;
     let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
 
-    let mut vs = VarStore::new(DType::F32, dev.clone());
+    let mut varmap = VarMap::new();
+    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
     let model = M::new(vs.clone())?;
 
     if let Some(load) = &args.load {
         println!("loading weights from {load}");
-        vs.load(load)?
+        varmap.load(load)?
     }
 
-    let all_vars = vs.all_vars();
-    let all_vars = all_vars.iter().collect::<Vec<_>>();
-    let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
+    let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
     let test_images = m.test_images.to_device(&dev)?;
     let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
     for epoch in 1..args.epochs {
@@ -215,7 +106,7 @@ fn training_loop<M: Model>(
     }
     if let Some(save) = &args.save {
         println!("saving trained weights in {save}");
-        vs.save(save)?
+        varmap.save(save)?
     }
     Ok(())
 }
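The --save and --load flags now go through VarMap::save and VarMap::load instead of the removed VarStore. A rough sketch of the round trip, with an illustrative file name and layer name; note that load only fills variables that already exist in the map and fails if one of them is missing from the file:

    use candle::{DType, Device, Result};
    use candle_nn::{VarBuilder, VarMap};

    // Sketch of the checkpoint round trip; "mlp.safetensors" and "ln1" are illustrative.
    fn checkpoint_roundtrip() -> Result<()> {
        let dev = Device::Cpu;
        let varmap = VarMap::new();
        let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let _model = candle_nn::linear(784, 10, vs.pp("ln1"))?;
        varmap.save("mlp.safetensors")?;

        // Later (or in another process): rebuild the same variables, then overwrite
        // their values with the saved ones.
        let mut varmap2 = VarMap::new();
        let vs2 = VarBuilder::from_varmap(&varmap2, DType::F32, &dev);
        let _model2 = candle_nn::linear(784, 10, vs2.pp("ln1"))?;
        varmap2.load("mlp.safetensors")?;
        Ok(())
    }
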
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index a0a853b0..050123be 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -28,3 +28,15 @@ impl Embedding {
         Ok(values)
     }
 }
+
+pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get_or_init(
+        (in_size, out_size),
+        "weight",
+        crate::Init::Randn {
+            mean: 0.,
+            stdev: 1.,
+        },
+    )?;
+    Ok(Embedding::new(embeddings, out_size))
+}
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 8f8544bb..668f9a4b 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -62,3 +62,9 @@ impl LayerNorm {
         Ok(x)
     }
 }
+
+pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+    let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
+    let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
+    Ok(LayerNorm::new(weight, bias, eps))
+}
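Both new constructors create their parameters on first use when the VarBuilder is backed by a VarMap: a random-normal table for embedding, and ones/zeros for the layer-norm weight and bias. A small usage sketch with illustrative sizes and names:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{VarBuilder, VarMap};

    // Sketch only: vocabulary/hidden sizes and the "tok_embeddings"/"norm" prefixes
    // are made up for illustration.
    fn embed_then_norm() -> Result<()> {
        let dev = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let emb = candle_nn::embedding(1000, 64, vb.pp("tok_embeddings"))?;
        let ln = candle_nn::layer_norm(64, 1e-5, vb.pp("norm"))?;
        let ids = Tensor::new(&[3u32, 7, 42], &dev)?;
        let xs = emb.forward(&ids)?; // (3, 64)
        let _ys = ln.forward(&xs)?; // still (3, 64)
        Ok(())
    }
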
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index e8086238..45edfc46 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -15,9 +15,9 @@ pub mod vision;
 
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
-pub use embedding::Embedding;
+pub use embedding::{embedding, Embedding};
 pub use init::Init;
-pub use layer_norm::LayerNorm;
-pub use linear::Linear;
+pub use layer_norm::{layer_norm, LayerNorm};
+pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::SGD;
-pub use var_builder::VarBuilder;
+pub use var_builder::{VarBuilder, VarMap};
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 943011c9..a0bd925a 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -17,7 +17,7 @@
 //! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
 //! # Ok(()) }
 //! ```
-use candle::Tensor;
+use candle::{Result, Tensor};
 
 #[derive(Debug)]
 pub struct Linear {
@@ -42,3 +42,24 @@ impl Linear {
         }
     }
 }
+
+/// Create or initialize a new linear layer.
+///
+/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
+pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
+    let bound = 1. / (in_dim as f64).sqrt();
+    let init_bs = crate::Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get_or_init(out_dim, "bias", init_bs)?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
+    Ok(Linear::new(ws, None))
+}
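The new linear helper registers a "weight" (default Kaiming-normal init) and a "bias" (uniform in ±1/sqrt(in_dim)), while linear_no_bias registers only the weight. A short sketch with illustrative layer names and sizes:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{VarBuilder, VarMap};

    // Sketch with made-up layer names and sizes.
    fn two_layers() -> Result<()> {
        let dev = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let proj = candle_nn::linear(128, 256, vb.pp("proj"))?;
        let head = candle_nn::linear_no_bias(256, 32, vb.pp("head"))?;
        let xs = Tensor::zeros((4, 128), DType::F32, &dev)?;
        let _ys = head.forward(&proj.forward(&xs)?)?; // (4, 32)
        // Three variables were registered: proj.weight, proj.bias and head.weight.
        assert_eq!(varmap.all_vars().len(), 3);
        Ok(())
    }
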
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index a8b5b370..d20ef284 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -8,7 +8,7 @@ pub struct SGD {
 }
 
 impl SGD {
-    pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
+    pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
         let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
         Self {
             vars,
@@ -16,6 +16,13 @@ impl SGD {
         }
     }
 
+    pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
+        Self {
+            vars,
+            learning_rate,
+        }
+    }
+
     pub fn empty(learning_rate: f64) -> Self {
         Self {
             vars: vec![],
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index be1380b7..374260b0 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,7 +1,87 @@
-use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
+use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor, Var};
 use safetensors::{slice::IndexOp, tensor::SafeTensors};
 use std::collections::HashMap;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
+
+/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
+/// and new variables can be added by providing some initialization config in case they are
+/// missing.
+/// `VarMap` structures can be serialized in the safetensors format.
+#[derive(Clone)]
+pub struct VarMap {
+    data: Arc<Mutex<HashMap<String, Var>>>,
+}
+
+impl VarMap {
+    /// Create a new empty `VarMap`.
+    #[allow(clippy::new_without_default)]
+    pub fn new() -> Self {
+        let data = Arc::new(Mutex::new(HashMap::new()));
+        Self { data }
+    }
+
+    /// Retrieve all the variables currently stored in the map.
+    pub fn all_vars(&self) -> Vec<Var> {
+        let tensor_data = self.data.lock().unwrap();
+        #[allow(clippy::map_clone)]
+        tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
+    }
+
+    /// Save the map in the safetensors format.
+    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
+        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
+        Ok(())
+    }
+
+    /// Load some values from a safetensors file and modify the existing variables to have these
+    /// values.
+    ///
+    /// Note that values for variables that are currently not in the map are not kept.
+    pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
+        let path = path.as_ref();
+        let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
+        let data = data.deserialize()?;
+        let mut tensor_data = self.data.lock().unwrap();
+        for (name, var) in tensor_data.iter_mut() {
+            match data.tensor(name) {
+                Ok(data) => {
+                    let data: Tensor = data.load(var.device())?;
+                    if let Err(err) = var.set(&data) {
+                        candle::bail!("error setting {name} using data from {path:?}: {err}",)
+                    }
+                }
+                Err(_) => candle::bail!("cannot find tensor for {name}"),
+            }
+        }
+        Ok(())
+    }
+
+    /// Retrieve or add a new variable.
+    pub fn get<S: Into<Shape>>(
+        &self,
+        shape: S,
+        path: &str,
+        init: crate::Init,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Tensor> {
+        let shape = shape.into();
+        let mut tensor_data = self.data.lock().unwrap();
+        if let Some(tensor) = tensor_data.get(path) {
+            let tensor_shape = tensor.shape();
+            if &shape != tensor_shape {
+                candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
+            }
+            return Ok(tensor.as_tensor().clone());
+        }
+        let var = init.var(shape, dtype, device)?;
+        let tensor = var.as_tensor().clone();
+        tensor_data.insert(path.to_string(), var);
+        Ok(tensor)
+    }
+}
 
 // TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
 // generics.
@@ -13,6 +93,7 @@ enum Tensors<'a> {
     Npz(candle::npy::NpzTensors),
     TensorMap(HashMap<String, Tensor>),
     Zeros,
+    VarMap(VarMap),
 }
 
 struct TensorData<'a> {
@@ -64,6 +145,14 @@ impl<'a> TensorData<'a> {
             dtype,
         })
     }
+
+    fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
+        Self {
+            tensors: Tensors::VarMap(varmap.clone()),
+            device: device.clone(),
+            dtype,
+        }
+    }
 }
 
 #[derive(Clone)]
@@ -99,6 +188,14 @@ impl<'a> VarBuilder<'a> {
         }
     }
 
+    pub fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
+        let data = TensorData::from_varmap(varmap, dtype, device);
+        Self {
+            data: Arc::new(data),
+            path: vec![],
+        }
+    }
+
     pub fn from_npz<P: AsRef<std::path::Path>>(
         file: P,
         dtype: DType,
@@ -154,11 +251,7 @@ impl<'a> VarBuilder<'a> {
         world_size: usize,
     ) -> Result<Tensor> {
         let data = self.data.as_ref();
-        let path = if self.path.is_empty() {
-            tensor_name.to_string()
-        } else {
-            [&self.path.join("."), tensor_name].join(".")
-        };
+        let path = self.path(tensor_name);
         let tensor = match &self.data.tensors {
             Tensors::SafeTensorWithRouting {
                 routing,
@@ -205,19 +298,16 @@ impl<'a> VarBuilder<'a> {
                 let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                 Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
             }
-            _ => unimplemented!(),
+            _ => candle::bail!("get_sharded is only available for safetensors"),
         };
         Ok(tensor)
    }
 
+    /// Retrieve the tensor associated with the current name and path.
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
-        let path = if self.path.is_empty() {
-            tensor_name.to_string()
-        } else {
-            [&self.path.join("."), tensor_name].join(".")
-        };
+        let path = self.path(tensor_name);
         let tensor = match &self.data.tensors {
             Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
             Tensors::TensorMap(ts) => ts
@@ -229,6 +319,18 @@ impl<'a> VarBuilder<'a> {
                     .bt()
                 })?
                 .clone(),
+            Tensors::VarMap(varmap) => {
+                let data = varmap.data.lock().unwrap();
+                data.get(&path)
+                    .ok_or_else(|| {
+                        Error::CannotFindTensor {
+                            path: path.to_string(),
+                        }
+                        .bt()
+                    })?
+                    .as_tensor()
+                    .clone()
+            }
             Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
                 Error::CannotFindTensor {
                     path: path.to_string(),
@@ -261,4 +363,32 @@ impl<'a> VarBuilder<'a> {
         }
         Ok(tensor)
     }
+
+    /// Retrieve the tensor associated with the current name and path or initialize a new tensor if
+    /// it's missing.
+    ///
+    /// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`.
+    pub fn get_or_init<S: Into<Shape>>(
+        &self,
+        s: S,
+        tensor_name: &str,
+        init: crate::Init,
+    ) -> Result<Tensor> {
+        let data = self.data.as_ref();
+        match &self.data.tensors {
+            Tensors::VarMap(varmap) => {
+                let path = self.path(tensor_name);
+                varmap.get(s, &path, init, data.dtype, &data.device)
+            }
+            _ => self.get(s, tensor_name),
+        }
+    }
+
+    fn path(&self, tensor_name: &str) -> String {
+        if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        }
+    }
 }
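With a VarMap-backed builder, get_or_init creates and registers a variable under the prefixed path on the first call, returns the same variable on later calls, and rejects a mismatched shape; on any other backend it falls back to plain get. A small sketch with illustrative names and shapes:

    use candle::{DType, Device, Result};
    use candle_nn::{VarBuilder, VarMap};

    // Sketch with made-up names and shapes.
    fn get_or_init_semantics() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let enc = vb.pp("encoder");
        // The first call creates and registers "encoder.weight"; the second returns the same var.
        let _w1 = enc.get_or_init((3, 4), "weight", candle_nn::init::ZERO)?;
        let _w2 = enc.get_or_init((3, 4), "weight", candle_nn::init::ZERO)?;
        assert_eq!(varmap.all_vars().len(), 1);
        // Asking for the same name with a different shape is rejected.
        assert!(enc.get_or_init((5, 5), "weight", candle_nn::init::ZERO).is_err());
        Ok(())
    }
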
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 8228e435..54c378cc 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -8,7 +8,7 @@ use candle_nn::{Linear, SGD};
 #[test]
 fn sgd_optim() -> Result<()> {
     let x = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(&[&x], 0.1);
+    let sgd = SGD::new(vec![x.clone()], 0.1);
     let xt = x.as_tensor();
     for _step in 0..100 {
         let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
@@ -54,7 +54,7 @@ fn sgd_linear_regression() -> Result<()> {
     // Now use backprop to run a linear regression between samples and get the coefficients back.
     let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
     let b = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(&[&w, &b], 0.004);
+    let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
     let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
     for _step in 0..1000 {
         let ys = lin.forward(&sample_xs)?;
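For call sites outside this diff, the migration is mechanical: the new SGD::new takes ownership of the Vars (typically varmap.all_vars()), and the old borrow-based constructor remains available as from_slice. A minimal sketch:

    use candle::{Device, Result, Var};

    // Sketch of the two constructors side by side.
    fn make_optimizers() -> Result<()> {
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let _owned = candle_nn::SGD::new(vec![w.clone(), b.clone()], 0.004);
        let _borrowed = candle_nn::SGD::from_slice(&[&w, &b], 0.004);
        Ok(())
    }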