mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the AdamW optimizer. (#307)
* Add the AdamW optimizer. * Add some AdamW test validated against PyTorch.
This commit is contained in:
@ -34,7 +34,7 @@
|
|||||||
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
|
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
|
||||||
|
|
||||||
pub mod backend;
|
pub mod backend;
|
||||||
mod backprop;
|
pub mod backprop;
|
||||||
mod conv;
|
mod conv;
|
||||||
mod convert;
|
mod convert;
|
||||||
pub mod cpu_backend;
|
pub mod cpu_backend;
|
||||||
|
@ -19,5 +19,5 @@ pub use embedding::{embedding, Embedding};
|
|||||||
pub use init::Init;
|
pub use init::Init;
|
||||||
pub use layer_norm::{layer_norm, LayerNorm};
|
pub use layer_norm::{layer_norm, LayerNorm};
|
||||||
pub use linear::{linear, linear_no_bias, Linear};
|
pub use linear::{linear, linear_no_bias, Linear};
|
||||||
pub use optim::SGD;
|
pub use optim::{AdamW, ParamsAdamW, SGD};
|
||||||
pub use var_builder::{VarBuilder, VarMap};
|
pub use var_builder::{VarBuilder, VarMap};
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
//! Various optimization algorithms.
|
//! Various optimization algorithms.
|
||||||
use candle::{Result, Tensor, Var};
|
use candle::{Result, Tensor, Var};
|
||||||
|
|
||||||
|
/// Optimizer for Stochastic Gradient Descent.
|
||||||
|
///
|
||||||
|
/// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct SGD {
|
pub struct SGD {
|
||||||
vars: Vec<Var>,
|
vars: Vec<Var>,
|
||||||
@ -42,8 +45,7 @@ impl SGD {
|
|||||||
self.vars.push(var.clone())
|
self.vars.push(var.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
|
pub fn step(&self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||||
let grads = loss.backward()?;
|
|
||||||
for var in self.vars.iter() {
|
for var in self.vars.iter() {
|
||||||
if let Some(grad) = grads.get(var) {
|
if let Some(grad) = grads.get(var) {
|
||||||
var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
|
var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
|
||||||
@ -51,4 +53,114 @@ impl SGD {
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
|
||||||
|
let grads = loss.backward()?;
|
||||||
|
self.step(&grads)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ParamsAdamW {
|
||||||
|
pub lr: f64,
|
||||||
|
pub beta1: f64,
|
||||||
|
pub beta2: f64,
|
||||||
|
pub eps: f64,
|
||||||
|
pub weight_decay: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ParamsAdamW {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
lr: 0.001,
|
||||||
|
beta1: 0.9,
|
||||||
|
beta2: 0.999,
|
||||||
|
eps: 1e-8,
|
||||||
|
weight_decay: 0.01,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct VarAdamW {
|
||||||
|
var: Var,
|
||||||
|
first_moment: Var,
|
||||||
|
second_moment: Var,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AdamW {
|
||||||
|
vars: Vec<VarAdamW>,
|
||||||
|
step_t: usize,
|
||||||
|
params: ParamsAdamW,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AdamW {
|
||||||
|
pub fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
|
||||||
|
let vars = vars
|
||||||
|
.into_iter()
|
||||||
|
.map(|var| {
|
||||||
|
let dtype = var.dtype();
|
||||||
|
let shape = var.shape();
|
||||||
|
let device = var.device();
|
||||||
|
let first_moment = Var::zeros(shape, dtype, device)?;
|
||||||
|
let second_moment = Var::zeros(shape, dtype, device)?;
|
||||||
|
Ok(VarAdamW {
|
||||||
|
var,
|
||||||
|
first_moment,
|
||||||
|
second_moment,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
Ok(Self {
|
||||||
|
vars,
|
||||||
|
params,
|
||||||
|
step_t: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
||||||
|
let params = ParamsAdamW {
|
||||||
|
lr: learning_rate,
|
||||||
|
..ParamsAdamW::default()
|
||||||
|
};
|
||||||
|
Self::new(vars, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||||
|
self.step_t += 1;
|
||||||
|
let lr = self.params.lr;
|
||||||
|
let lambda = self.params.weight_decay;
|
||||||
|
let lr_lambda = lr * lambda;
|
||||||
|
let beta1 = self.params.beta1;
|
||||||
|
let beta2 = self.params.beta2;
|
||||||
|
let scale_m = 1f64 / (1f64 - beta1.powi(self.step_t as i32));
|
||||||
|
let scale_v = 1f64 / (1f64 - beta2.powi(self.step_t as i32));
|
||||||
|
for var in self.vars.iter() {
|
||||||
|
let theta = &var.var;
|
||||||
|
let m = &var.first_moment;
|
||||||
|
let v = &var.second_moment;
|
||||||
|
if let Some(g) = grads.get(theta) {
|
||||||
|
// This involves locking 3 RWLocks per params, if the parameters are large this
|
||||||
|
// should not be an issue but this may be problematic with models with lots of
|
||||||
|
// small parameters.
|
||||||
|
let next_m = ((m.as_tensor() * beta1)? + (g * (1.0 - beta1))?)?;
|
||||||
|
let next_v = ((v.as_tensor() * beta2)? + (g.sqr()? * (1.0 - beta2))?)?;
|
||||||
|
let m_hat = (&next_m * scale_m)?;
|
||||||
|
let v_hat = (&next_v * scale_v)?;
|
||||||
|
let next_theta = (theta.as_tensor() * (1f64 - lr_lambda))?;
|
||||||
|
let adjusted_grad = (m_hat / (v_hat.sqrt()? + self.params.eps)?)?;
|
||||||
|
let next_theta = (next_theta - (adjusted_grad * lr)?)?;
|
||||||
|
m.set(&next_m)?;
|
||||||
|
v.set(&next_v)?;
|
||||||
|
theta.set(&next_theta)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
|
||||||
|
let grads = loss.backward()?;
|
||||||
|
self.step(&grads)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,10 @@
|
|||||||
use candle::{Device, Result, Tensor};
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
|
mod test_utils;
|
||||||
let b = 10f32.powi(digits);
|
use test_utils::to_vec3_round;
|
||||||
let t = t.to_vec3::<f32>()?;
|
|
||||||
let t = t
|
use candle::{Device, Result, Tensor};
|
||||||
.iter()
|
|
||||||
.map(|t| {
|
|
||||||
t.iter()
|
|
||||||
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
|
|
||||||
.collect()
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Ok(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn softmax() -> Result<()> {
|
fn softmax() -> Result<()> {
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
mod test_utils;
|
||||||
|
use test_utils::{to_vec0_round, to_vec2_round};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle::{Device, Tensor, Var};
|
use candle::{Device, Tensor, Var};
|
||||||
use candle_nn::{Linear, SGD};
|
use candle_nn::{AdamW, Linear, ParamsAdamW, SGD};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sgd_optim() -> Result<()> {
|
fn sgd_optim() -> Result<()> {
|
||||||
@ -65,3 +68,54 @@ fn sgd_linear_regression() -> Result<()> {
|
|||||||
assert_eq!(b.to_scalar::<f32>()?, -1.9796902);
|
assert_eq!(b.to_scalar::<f32>()?, -1.9796902);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* The following test returns the same values as the PyTorch code below.
|
||||||
|
import torch
|
||||||
|
from torch import optim
|
||||||
|
|
||||||
|
w_gen = torch.tensor([[3., 1.]])
|
||||||
|
b_gen = torch.tensor([-2.])
|
||||||
|
|
||||||
|
sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
|
||||||
|
sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
|
||||||
|
|
||||||
|
m = torch.nn.Linear(2, 1)
|
||||||
|
with torch.no_grad():
|
||||||
|
m.weight.zero_()
|
||||||
|
m.bias.zero_()
|
||||||
|
optimizer = optim.AdamW(m.parameters(), lr=0.1)
|
||||||
|
for _step in range(100):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
ys = m(sample_xs)
|
||||||
|
loss = ((ys - sample_ys)**2).sum()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
print(m.weight)
|
||||||
|
print(m.bias)
|
||||||
|
*/
|
||||||
|
#[test]
|
||||||
|
fn adamw_linear_regression() -> Result<()> {
|
||||||
|
let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
|
||||||
|
let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
|
||||||
|
let gen = Linear::new(w_gen, Some(b_gen));
|
||||||
|
let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
|
||||||
|
let sample_ys = gen.forward(&sample_xs)?;
|
||||||
|
|
||||||
|
// Now use backprop to run a linear regression between samples and get the coefficients back.
|
||||||
|
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
|
||||||
|
let b = Var::new(0f32, &Device::Cpu)?;
|
||||||
|
let params = ParamsAdamW {
|
||||||
|
lr: 0.1,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut opt = AdamW::new(vec![w.clone(), b.clone()], params)?;
|
||||||
|
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
|
||||||
|
for _step in 0..100 {
|
||||||
|
let ys = lin.forward(&sample_xs)?;
|
||||||
|
let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
|
||||||
|
opt.backward_step(&loss)?;
|
||||||
|
}
|
||||||
|
assert_eq!(to_vec2_round(w.as_tensor(), 4)?, &[[2.7257, 0.7097]]);
|
||||||
|
assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
39
candle-nn/tests/test_utils.rs
Normal file
39
candle-nn/tests/test_utils.rs
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
#![allow(dead_code)]
|
||||||
|
use candle::{Result, Tensor};
|
||||||
|
|
||||||
|
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
||||||
|
let b = 10f32.powi(digits);
|
||||||
|
let t = t.to_vec0::<f32>()?;
|
||||||
|
Ok(f32::round(t * b) / b)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
|
||||||
|
let b = 10f32.powi(digits);
|
||||||
|
let t = t.to_vec1::<f32>()?;
|
||||||
|
let t = t.iter().map(|t| f32::round(t * b) / b).collect();
|
||||||
|
Ok(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
|
||||||
|
let b = 10f32.powi(digits);
|
||||||
|
let t = t.to_vec2::<f32>()?;
|
||||||
|
let t = t
|
||||||
|
.iter()
|
||||||
|
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
|
||||||
|
.collect();
|
||||||
|
Ok(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
|
||||||
|
let b = 10f32.powi(digits);
|
||||||
|
let t = t.to_vec3::<f32>()?;
|
||||||
|
let t = t
|
||||||
|
.iter()
|
||||||
|
.map(|t| {
|
||||||
|
t.iter()
|
||||||
|
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(t)
|
||||||
|
}
|
Reference in New Issue
Block a user