Add the AdamW optimizer. (#307)

* Add the AdamW optimizer.

* Add an AdamW test validated against PyTorch.
Laurent Mazare
2023-08-02 14:03:49 +01:00
committed by GitHub
parent e2acbe1e72
commit 0902846f25
6 changed files with 216 additions and 19 deletions
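
For orientation, here is a minimal sketch of how the new optimizer could be driven from a training loop once this lands. Only the AdamW / ParamsAdamW API (constructor and backward_step) is taken from the diff below; the module path candle_nn::optim, the Var/Tensor setup and the dummy loss are illustrative assumptions, not part of the commit.

use candle::{DType, Device, Result, Tensor, Var};
use candle_nn::optim::{AdamW, ParamsAdamW}; // assumed re-export path for the optim module below

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A single learnable parameter, registered with the optimizer below.
    let w = Var::zeros((2, 2), DType::F32, &dev)?;
    let params = ParamsAdamW {
        lr: 1e-3,
        ..ParamsAdamW::default()
    };
    let mut opt = AdamW::new(vec![w.clone()], params)?;
    // Dummy objective: squared distance between w and a target of ones.
    let target = Tensor::ones((2, 2), DType::F32, &dev)?;
    for _ in 0..5 {
        let loss = (w.as_tensor() - &target)?.sqr()?.sum_all()?;
        // backward_step computes the gradients and applies one AdamW update.
        opt.backward_step(&loss)?;
    }
    Ok(())
}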

@@ -1,6 +1,9 @@
 //! Various optimization algorithms.
 use candle::{Result, Tensor, Var};
 
+/// Optimizer for Stochastic Gradient Descent.
+///
+/// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
 #[derive(Debug)]
 pub struct SGD {
     vars: Vec<Var>,
@@ -42,8 +45,7 @@ impl SGD {
         self.vars.push(var.clone())
     }
 
-    pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
-        let grads = loss.backward()?;
+    pub fn step(&self, grads: &candle::backprop::GradStore) -> Result<()> {
         for var in self.vars.iter() {
             if let Some(grad) = grads.get(var) {
                 var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
@@ -51,4 +53,114 @@ impl SGD {
         }
         Ok(())
     }
+
+    pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
+        let grads = loss.backward()?;
+        self.step(&grads)
+    }
 }
+
+#[derive(Clone, Debug)]
+pub struct ParamsAdamW {
+    pub lr: f64,
+    pub beta1: f64,
+    pub beta2: f64,
+    pub eps: f64,
+    pub weight_decay: f64,
+}
+
+impl Default for ParamsAdamW {
+    fn default() -> Self {
+        Self {
+            lr: 0.001,
+            beta1: 0.9,
+            beta2: 0.999,
+            eps: 1e-8,
+            weight_decay: 0.01,
+        }
+    }
+}
+
+#[derive(Debug)]
+struct VarAdamW {
+    var: Var,
+    first_moment: Var,
+    second_moment: Var,
+}
+
+#[derive(Debug)]
+pub struct AdamW {
+    vars: Vec<VarAdamW>,
+    step_t: usize,
+    params: ParamsAdamW,
+}
+
+impl AdamW {
+    pub fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
+        let vars = vars
+            .into_iter()
+            .map(|var| {
+                let dtype = var.dtype();
+                let shape = var.shape();
+                let device = var.device();
+                let first_moment = Var::zeros(shape, dtype, device)?;
+                let second_moment = Var::zeros(shape, dtype, device)?;
+                Ok(VarAdamW {
+                    var,
+                    first_moment,
+                    second_moment,
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self {
+            vars,
+            params,
+            step_t: 0,
+        })
+    }
+
+    pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+        let params = ParamsAdamW {
+            lr: learning_rate,
+            ..ParamsAdamW::default()
+        };
+        Self::new(vars, params)
+    }
+
+    pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
+        self.step_t += 1;
+        let lr = self.params.lr;
+        let lambda = self.params.weight_decay;
+        let lr_lambda = lr * lambda;
+        let beta1 = self.params.beta1;
+        let beta2 = self.params.beta2;
+        let scale_m = 1f64 / (1f64 - beta1.powi(self.step_t as i32));
+        let scale_v = 1f64 / (1f64 - beta2.powi(self.step_t as i32));
+        for var in self.vars.iter() {
+            let theta = &var.var;
+            let m = &var.first_moment;
+            let v = &var.second_moment;
+            if let Some(g) = grads.get(theta) {
+                // This involves locking 3 RwLocks per parameter; if the parameters are large this
+                // should not be an issue, but it may be problematic for models with lots of
+                // small parameters.
+                let next_m = ((m.as_tensor() * beta1)? + (g * (1.0 - beta1))?)?;
+                let next_v = ((v.as_tensor() * beta2)? + (g.sqr()? * (1.0 - beta2))?)?;
+                let m_hat = (&next_m * scale_m)?;
+                let v_hat = (&next_v * scale_v)?;
+                let next_theta = (theta.as_tensor() * (1f64 - lr_lambda))?;
+                let adjusted_grad = (m_hat / (v_hat.sqrt()? + self.params.eps)?)?;
+                let next_theta = (next_theta - (adjusted_grad * lr)?)?;
+                m.set(&next_m)?;
+                v.set(&next_v)?;
+                theta.set(&next_theta)?;
+            }
+        }
+        Ok(())
+    }
+
+    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
+        let grads = loss.backward()?;
+        self.step(&grads)
+    }
+}
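
For reference, the update applied by AdamW::step can be written out as follows. This is a transcription of the code above, not an independent specification: scale_m and scale_v are the bias-correction factors, lr_lambda the decoupled weight-decay term, and g_t the gradient of the loss with respect to the parameter theta.

\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat m_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t} \\
\theta_t &= \theta_{t-1}\,(1 - \mathrm{lr}\cdot\lambda) - \mathrm{lr}\cdot\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
\end{aligned}

With weight_decay = 0 this reduces to plain Adam; applying the (1 - lr·λ) factor directly to θ rather than folding the decay into the gradient is what makes the weight decay decoupled, matching the behaviour of PyTorch's torch.optim.AdamW that the new test is validated against.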