diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index 3e93c786..150a3272 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -1,6 +1,7 @@
 use crate::model::{Cache, Config, Llama};
 use candle::{DType, Device, Result};
 use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
+use candle_nn::Optimizer;
 
 fn valid_loss(
     dataset: &Dataset,
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index 06986681..a07505bf 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
 use rand::prelude::*;
 
 use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Conv2d, Linear, Module, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
@@ -190,7 +190,7 @@ fn training_loop(
         varmap.load(load)?
     }
 
-    let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
+    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
     let test_images = m.test_images.to_device(&dev)?;
     let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
     for epoch in 1..args.epochs {
diff --git a/candle-nn/examples/basic_optimizer.rs b/candle-nn/examples/basic_optimizer.rs
index 093bda81..810f7a7a 100644
--- a/candle-nn/examples/basic_optimizer.rs
+++ b/candle-nn/examples/basic_optimizer.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{linear, AdamW, Linear, Module, ParamsAdamW, VarBuilder, VarMap};
+use candle_nn::{linear, AdamW, Linear, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};
 
 fn gen_data() -> Result<(Tensor, Tensor)> {
     // Generate some sample linear data.
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 48046081..6e268f4e 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -24,7 +24,7 @@ pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use ops::Dropout;
-pub use optim::{AdamW, ParamsAdamW, SGD};
+pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
 pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
 pub use var_builder::VarBuilder;
 pub use var_map::VarMap;
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index b5ac9dba..4294d75e 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -1,6 +1,33 @@
 //! Various optimization algorithms.
 use candle::{Result, Tensor, Var};
 
+/// The interface optimizers should implement.
+pub trait Optimizer: Sized {
+    type Config: Sized;
+
+    fn new(vars: Vec<Var>, config: Self::Config) -> Result<Self>;
+
+    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()>;
+
+    fn learning_rate(&self) -> f64;
+
+    fn set_learning_rate(&mut self, lr: f64);
+
+    fn empty(config: Self::Config) -> Result<Self> {
+        Self::new(vec![], config)
+    }
+
+    fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
+        let grads = loss.backward()?;
+        self.step(&grads)
+    }
+
+    fn from_slice(vars: &[&Var], config: Self::Config) -> Result<Self> {
+        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
+        Self::new(vars, config)
+    }
+}
+
 /// Optimizer for Stochastic Gradient Descent.
 ///
 /// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
@@ -10,42 +37,21 @@ pub struct SGD {
     learning_rate: f64,
 }
 
-impl SGD {
-    pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
-        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
-        Self {
+impl Optimizer for SGD {
+    type Config = f64;
+
+    fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+        Ok(Self {
             vars,
             learning_rate,
-        }
+        })
     }
 
-    pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
-        Self {
-            vars,
-            learning_rate,
-        }
-    }
-
-    pub fn empty(learning_rate: f64) -> Self {
-        Self {
-            vars: vec![],
-            learning_rate,
-        }
-    }
-
-    pub fn into_inner(self) -> Vec<Var> {
-        self.vars
-    }
-
-    pub fn learning_rate(&self) -> f64 {
+    fn learning_rate(&self) -> f64 {
         self.learning_rate
     }
 
-    pub fn push(&mut self, var: &Var) {
-        self.vars.push(var.clone())
-    }
-
-    pub fn step(&self, grads: &candle::backprop::GradStore) -> Result<()> {
+    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
         for var in self.vars.iter() {
             if let Some(grad) = grads.get(var) {
                 var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
@@ -54,13 +60,18 @@ impl SGD {
         Ok(())
     }
 
-    pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
-        let grads = loss.backward()?;
-        self.step(&grads)
+    fn set_learning_rate(&mut self, lr: f64) {
+        self.learning_rate = lr
+    }
+}
+
+impl SGD {
+    pub fn into_inner(self) -> Vec<Var> {
+        self.vars
     }
 
-    pub fn set_learning_rate(&mut self, lr: f64) {
-        self.learning_rate = lr
+    pub fn push(&mut self, var: &Var) {
+        self.vars.push(var.clone())
     }
 }
 
@@ -99,8 +110,10 @@ pub struct AdamW {
     params: ParamsAdamW,
 }
 
-impl AdamW {
-    pub fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
+impl Optimizer for AdamW {
+    type Config = ParamsAdamW;
+
+    fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
         let vars = vars
             .into_iter()
             .map(|var| {
@@ -123,19 +136,15 @@ impl AdamW {
         })
     }
 
-    pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
-        let params = ParamsAdamW {
-            lr: learning_rate,
-            ..ParamsAdamW::default()
-        };
-        Self::new(vars, params)
+    fn learning_rate(&self) -> f64 {
+        self.params.lr
     }
 
-    pub fn set_learning_rate(&mut self, lr: f64) {
+    fn set_learning_rate(&mut self, lr: f64) {
         self.params.lr = lr
     }
 
-    pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
+    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
         self.step_t += 1;
         let lr = self.params.lr;
         let lambda = self.params.weight_decay;
@@ -166,9 +175,14 @@ impl AdamW {
         }
         Ok(())
     }
+}
 
-    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
-        let grads = loss.backward()?;
-        self.step(&grads)
+impl AdamW {
+    pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+        let params = ParamsAdamW {
+            lr: learning_rate,
+            ..ParamsAdamW::default()
+        };
+        Self::new(vars, params)
     }
 }
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 673d0455..841f65c8 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -8,12 +8,12 @@ use candle::test_utils::{to_vec0_round, to_vec2_round};
 
 use anyhow::Result;
 use candle::{Device, Tensor, Var};
-use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD};
+use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
 
 #[test]
 fn sgd_optim() -> Result<()> {
     let x = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(vec![x.clone()], 0.1);
+    let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
     let xt = x.as_tensor();
     for _step in 0..100 {
         let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
@@ -59,7 +59,7 @@ fn sgd_linear_regression() -> Result<()> {
     // Now use backprop to run a linear regression between samples and get the coefficients back.
     let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
     let b = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
+    let mut sgd = SGD::new(vec![w.clone(), b.clone()], 0.004)?;
     let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
     for _step in 0..1000 {
         let ys = lin.forward(&sample_xs)?;
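
For reference, a minimal usage sketch of the trait-based API introduced by this diff, mirroring the updated `sgd_optim` test above; the `main` wrapper is illustrative only and not part of the change:

```rust
use candle::{Device, Result, Var};
use candle_nn::{Optimizer, SGD}; // `Optimizer` must be in scope for the trait methods.

fn main() -> Result<()> {
    // Minimize (x - 4.2)^2 with plain SGD.
    let x = Var::new(0f32, &Device::Cpu)?;
    // `SGD::new` is now the trait's constructor: it returns a `Result`, and
    // `step`/`backward_step` take `&mut self`, hence the `mut` and the `?`.
    let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
    let xt = x.as_tensor();
    for _step in 0..100 {
        let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
        sgd.backward_step(&loss)?;
    }
    println!("x after training: {}", x.as_tensor());
    Ok(())
}
```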