Add the optimizer trait. (#702)

This commit is contained in:
Laurent Mazare
2023-09-01 13:55:39 +02:00
committed by GitHub
parent f2d476ca65
commit 7529531056
6 changed files with 69 additions and 54 deletions

View File

@ -1,6 +1,7 @@
use crate::model::{Cache, Config, Llama}; use crate::model::{Cache, Config, Llama};
use candle::{DType, Device, Result}; use candle::{DType, Device, Result};
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter}; use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
use candle_nn::Optimizer;
fn valid_loss( fn valid_loss(
dataset: &Dataset, dataset: &Dataset,

View File

@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
use rand::prelude::*; use rand::prelude::*;
use candle::{DType, Result, Tensor, D}; use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, VarBuilder, VarMap}; use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
const IMAGE_DIM: usize = 784; const IMAGE_DIM: usize = 784;
const LABELS: usize = 10; const LABELS: usize = 10;
@ -190,7 +190,7 @@ fn training_loop<M: Model>(
varmap.load(load)? varmap.load(load)?
} }
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate); let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
let test_images = m.test_images.to_device(&dev)?; let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?; let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..args.epochs { for epoch in 1..args.epochs {

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src; extern crate accelerate_src;
use candle::{DType, Device, Result, Tensor}; use candle::{DType, Device, Result, Tensor};
use candle_nn::{linear, AdamW, Linear, Module, ParamsAdamW, VarBuilder, VarMap}; use candle_nn::{linear, AdamW, Linear, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};
fn gen_data() -> Result<(Tensor, Tensor)> { fn gen_data() -> Result<(Tensor, Tensor)> {
// Generate some sample linear data. // Generate some sample linear data.

View File

@ -24,7 +24,7 @@ pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_no_bias, Linear}; pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout; pub use ops::Dropout;
pub use optim::{AdamW, ParamsAdamW, SGD}; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
pub use var_builder::VarBuilder; pub use var_builder::VarBuilder;
pub use var_map::VarMap; pub use var_map::VarMap;

View File

@ -1,6 +1,33 @@
//! Various optimization algorithms. //! Various optimization algorithms.
use candle::{Result, Tensor, Var}; use candle::{Result, Tensor, Var};
/// The interface optimizers should implement.
pub trait Optimizer: Sized {
type Config: Sized;
fn new(vars: Vec<Var>, config: Self::Config) -> Result<Self>;
fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()>;
fn learning_rate(&self) -> f64;
fn set_learning_rate(&mut self, lr: f64);
fn empty(config: Self::Config) -> Result<Self> {
Self::new(vec![], config)
}
fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
let grads = loss.backward()?;
self.step(&grads)
}
fn from_slice(vars: &[&Var], config: Self::Config) -> Result<Self> {
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
Self::new(vars, config)
}
}
/// Optimizer for Stochastic Gradient Descent. /// Optimizer for Stochastic Gradient Descent.
/// ///
/// Contrary to the PyTorch implementation of SGD, this version does not support momentum. /// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
@ -10,42 +37,21 @@ pub struct SGD {
learning_rate: f64, learning_rate: f64,
} }
impl SGD { impl Optimizer for SGD {
pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self { type Config = f64;
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
Self { fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
Ok(Self {
vars, vars,
learning_rate, learning_rate,
} })
} }
pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self { fn learning_rate(&self) -> f64 {
Self {
vars,
learning_rate,
}
}
pub fn empty(learning_rate: f64) -> Self {
Self {
vars: vec![],
learning_rate,
}
}
pub fn into_inner(self) -> Vec<Var> {
self.vars
}
pub fn learning_rate(&self) -> f64 {
self.learning_rate self.learning_rate
} }
pub fn push(&mut self, var: &Var) { fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
self.vars.push(var.clone())
}
pub fn step(&self, grads: &candle::backprop::GradStore) -> Result<()> {
for var in self.vars.iter() { for var in self.vars.iter() {
if let Some(grad) = grads.get(var) { if let Some(grad) = grads.get(var) {
var.set(&var.sub(&(grad * self.learning_rate)?)?)?; var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
@ -54,13 +60,18 @@ impl SGD {
Ok(()) Ok(())
} }
pub fn backward_step(&self, loss: &Tensor) -> Result<()> { fn set_learning_rate(&mut self, lr: f64) {
let grads = loss.backward()?; self.learning_rate = lr
self.step(&grads) }
}
impl SGD {
pub fn into_inner(self) -> Vec<Var> {
self.vars
} }
pub fn set_learning_rate(&mut self, lr: f64) { pub fn push(&mut self, var: &Var) {
self.learning_rate = lr self.vars.push(var.clone())
} }
} }
@ -99,8 +110,10 @@ pub struct AdamW {
params: ParamsAdamW, params: ParamsAdamW,
} }
impl AdamW { impl Optimizer for AdamW {
pub fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> { type Config = ParamsAdamW;
fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
let vars = vars let vars = vars
.into_iter() .into_iter()
.map(|var| { .map(|var| {
@ -123,19 +136,15 @@ impl AdamW {
}) })
} }
pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> { fn learning_rate(&self) -> f64 {
let params = ParamsAdamW { self.params.lr
lr: learning_rate,
..ParamsAdamW::default()
};
Self::new(vars, params)
} }
pub fn set_learning_rate(&mut self, lr: f64) { fn set_learning_rate(&mut self, lr: f64) {
self.params.lr = lr self.params.lr = lr
} }
pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> { fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
self.step_t += 1; self.step_t += 1;
let lr = self.params.lr; let lr = self.params.lr;
let lambda = self.params.weight_decay; let lambda = self.params.weight_decay;
@ -166,9 +175,14 @@ impl AdamW {
} }
Ok(()) Ok(())
} }
}
pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> { impl AdamW {
let grads = loss.backward()?; pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
self.step(&grads) let params = ParamsAdamW {
lr: learning_rate,
..ParamsAdamW::default()
};
Self::new(vars, params)
} }
} }

View File

@ -8,12 +8,12 @@ use candle::test_utils::{to_vec0_round, to_vec2_round};
use anyhow::Result; use anyhow::Result;
use candle::{Device, Tensor, Var}; use candle::{Device, Tensor, Var};
use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD}; use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
#[test] #[test]
fn sgd_optim() -> Result<()> { fn sgd_optim() -> Result<()> {
let x = Var::new(0f32, &Device::Cpu)?; let x = Var::new(0f32, &Device::Cpu)?;
let sgd = SGD::new(vec![x.clone()], 0.1); let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
let xt = x.as_tensor(); let xt = x.as_tensor();
for _step in 0..100 { for _step in 0..100 {
let loss = ((xt - 4.2)? * (xt - 4.2)?)?; let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
@ -59,7 +59,7 @@ fn sgd_linear_regression() -> Result<()> {
// Now use backprop to run a linear regression between samples and get the coefficients back. // Now use backprop to run a linear regression between samples and get the coefficients back.
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
let b = Var::new(0f32, &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?;
let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004); let mut sgd = SGD::new(vec![w.clone(), b.clone()], 0.004)?;
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
for _step in 0..1000 { for _step in 0..1000 {
let ys = lin.forward(&sample_xs)?; let ys = lin.forward(&sample_xs)?;