From 0bb344f7988e55ada3e186247d76936387503017 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 16 Aug 2023 14:39:36 +0200
Subject: [PATCH] [RFC] Start removal of `VarBuilder`.

- Uses the `Initializer` trait instead.
- Allows more decoupling between init and load, which are very different ops.
- Allows more decoupling between backends (safetensors, npy, ggml, etc...)

This is a minimum viable change. There are 3 kinds of objects with various
relations.

The `Model`: this is `Llama`, `Linear`, `Rms`, ... They contain tensors (and
possibly other things) and are basically used to call `forward`. They should
not own any internals such as RNG state or the actual shapes of the tensors
(the tensors already own those).

The `Initializer`: this is a struct containing the information necessary to
generate new random tensors. Typically it should own a random generator and
generate different kinds of random tensors depending on what kind of `Model`
it is initializing. It does not own any information about the `Model` itself.
The default init (`DefaultInit`) stores a `Vec<Var>` for now, in order to
send the trainable variables to the optimizer.

The `Config`: this is the information necessary to link the `Model` and the
`Initializer`; it is another struct, a companion of the initialization
implementation. Typical contents are the shapes of the tensors for a simple
`Model`, the `eps` for `Rms`, or the `use_bias` boolean saying whether the
linear layer should have a bias.

This should remove all need for `VarBuilder` during initialization, and allow
removing every initialization bit within `VarBuilder`.

Modifying `llama2-c` to follow that initialization is left on purpose for a
follow-up, to keep the current PR rather small.
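To make the three roles concrete, this is how they line up for `Linear` once
this patch is applied (a sketch, to be read inside a function returning
`Result`; the dims mirror the example below):

    use candle::{DType, Device};
    use candle_nn::init::{DefaultInit, ModelInitializer};
    use candle_nn::Linear;

    // `Model` = `Linear`, `Initializer` = `DefaultInit`,
    // `Config` = ((out_dim, in_dim), use_bias).
    let mut init = DefaultInit::new(DType::F32, Device::Cpu);
    let model = Linear::init(&mut init, ((1, 2), true))?;
    // `DefaultInit` recorded the freshly created `Var`s, so they can be
    // handed to the optimizer afterwards.
    let trainable = init.vars().to_vec();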
---
 candle-nn/examples/basic_optimizer.rs | 12 ++++---
 candle-nn/src/init.rs                 | 44 ++++++++++++++++++++++++++
 candle-nn/src/linear.rs               | 45 ++++++++++++++++++++-------
 3 files changed, 85 insertions(+), 16 deletions(-)

diff --git a/candle-nn/examples/basic_optimizer.rs b/candle-nn/examples/basic_optimizer.rs
index 3c5665e8..fc3b7286 100644
--- a/candle-nn/examples/basic_optimizer.rs
+++ b/candle-nn/examples/basic_optimizer.rs
@@ -1,5 +1,6 @@
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{linear, AdamW, Linear, ParamsAdamW, VarBuilder, VarMap};
+use candle_nn::init::{DefaultInit, ModelInitializer};
+use candle_nn::{AdamW, Linear, ParamsAdamW};
 
 fn gen_data() -> Result<(Tensor, Tensor)> {
     // Generate some sample linear data.
@@ -15,14 +16,15 @@ fn main() -> Result<()> {
     let (sample_xs, sample_ys) = gen_data()?;
 
     // Use backprop to run a linear regression between samples and get the coefficients back.
-    let varmap = VarMap::new();
-    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
-    let model = linear(2, 1, vb.pp("linear"))?;
+    // let varmap = VarMap::new();
+    // let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
+    let mut initializer = DefaultInit::new(DType::F32, Device::Cpu);
+    let model = Linear::init(&mut initializer, ((1, 2), true))?;
     let params = ParamsAdamW {
         lr: 0.1,
         ..Default::default()
     };
-    let mut opt = AdamW::new(varmap.all_vars(), params)?;
+    let mut opt = AdamW::new(initializer.vars().to_vec(), params)?;
     for step in 0..10000 {
         let ys = model.forward(&sample_xs)?;
         let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
diff --git a/candle-nn/src/init.rs b/candle-nn/src/init.rs
index 25702d52..93978964 100644
--- a/candle-nn/src/init.rs
+++ b/candle-nn/src/init.rs
@@ -3,6 +3,50 @@
 // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
 use candle::{DType, Device, Result, Shape, Tensor, Var};
 
+pub trait Initializer<M> {
+    type Config;
+    fn init(&mut self, config: Self::Config) -> Result<M>;
+}
+
+pub trait ModelInitializer: Sized {
+    fn init<INIT: Initializer<Self>>(initializer: &mut INIT, config: INIT::Config) -> Result<Self> {
+        initializer.init(config)
+    }
+}
+
+pub struct DefaultInit {
+    vars: Vec<Var>,
+    dtype: DType,
+    device: Device,
+}
+
+impl DefaultInit {
+    pub fn new(dtype: DType, device: Device) -> Self {
+        let vars = vec![];
+        Self {
+            dtype,
+            device,
+            vars,
+        }
+    }
+
+    pub fn dtype(&self) -> DType {
+        self.dtype
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+
+    pub fn vars(&self) -> &[Var] {
+        &self.vars
+    }
+
+    pub fn push_var(&mut self, var: Var) {
+        self.vars.push(var)
+    }
+}
+
 /// Number of features as input or output of a layer.
 /// In Kaiming initialization, choosing `FanIn` preserves
 /// the magnitude of the variance of the weights in the
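The trait wiring above is fully generic; each model then declares its own
`Config` next to its definition, as `linear.rs` does below. For composite
models the same pattern should nest, with a parent `Config` aggregating its
children's. A hypothetical sketch (this `Mlp` is not part of the patch):

    use candle::Result;
    use candle_nn::init::{DefaultInit, Initializer, ModelInitializer};
    use candle_nn::Linear;

    struct Mlp {
        fc1: Linear,
        fc2: Linear,
    }

    impl ModelInitializer for Mlp {}

    impl Initializer<Mlp> for DefaultInit {
        // One `Linear` config per sub-layer.
        type Config = (((usize, usize), bool), ((usize, usize), bool));

        fn init(&mut self, (cfg1, cfg2): Self::Config) -> Result<Mlp> {
            // Delegate to the `Initializer<Linear>` impl for each layer;
            // all created `Var`s accumulate in the same `DefaultInit`.
            let fc1 = Linear::init(self, cfg1)?;
            let fc2 = Linear::init(self, cfg2)?;
            Ok(Mlp { fc1, fc2 })
        }
    }

Nothing about `Mlp` lives in `DefaultInit`; the nesting is carried entirely
by the `Config` tuples.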
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index a0bd925a..ca4032ba 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -17,6 +17,7 @@
 //! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
 //! # Ok(()) }
 //! ```
+use crate::init::{DefaultInit, Initializer, ModelInitializer};
 use candle::{Result, Tensor};
 
 #[derive(Debug)]
@@ -43,23 +44,45 @@ impl Linear {
     }
 }
 
-/// Create or initialize a new linear layer.
+impl ModelInitializer for Linear {}
+
+impl Initializer<Linear> for DefaultInit {
+    type Config = ((usize, usize), bool);
+
+    fn init(&mut self, (shape, has_bias): Self::Config) -> Result<Linear> {
+        let dtype = self.dtype();
+        let device = self.device().clone();
+        let (out_dim, in_dim) = shape;
+        let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+        let ws = init_ws.var(shape, dtype, &device)?;
+        self.push_var(ws.clone());
+        let ws = ws.as_tensor().clone();
+        if has_bias {
+            let bound = 1. / (in_dim as f64).sqrt();
+            let init_bs = crate::Init::Uniform {
+                lo: -bound,
+                up: bound,
+            };
+            let bs = init_bs.var(out_dim, dtype, &device)?;
+            self.push_var(bs.clone());
+            let bs = bs.as_tensor().clone();
+            Ok(Linear::new(ws, Some(bs)))
+        } else {
+            Ok(Linear::new(ws, None))
+        }
+    }
+}
+
+/// Loads a linear layer.
 ///
 /// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
 pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
-    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
-    let bound = 1. / (in_dim as f64).sqrt();
-    let init_bs = crate::Init::Uniform {
-        lo: -bound,
-        up: bound,
-    };
-    let bs = vs.get_or_init(out_dim, "bias", init_bs)?;
+    let ws = vs.get((out_dim, in_dim), "weight")?;
+    let bs = vs.get(out_dim, "bias")?;
     Ok(Linear::new(ws, Some(bs)))
 }
 
 pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
-    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
+    let ws = vs.get((out_dim, in_dim), "weight")?;
     Ok(Linear::new(ws, None))
 }
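Note that with the `use_bias` boolean living in the `Config`, the init path no
longer needs the `linear` / `linear_no_bias` pair; both variants come out of
the same call (a sketch, inside a function returning `Result`):

    let with_bias = Linear::init(&mut init, ((out_dim, in_dim), true))?;
    let no_bias = Linear::init(&mut init, ((out_dim, in_dim), false))?;

The two `pub fn`s above remain for the load path only, which now goes through
`VarBuilder::get` without any init logic.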