mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the optimizer trait. (#702)
This commit is contained in:
@ -1,6 +1,33 @@
|
||||
//! Various optimization algorithms.
|
||||
use candle::{Result, Tensor, Var};
|
||||
|
||||
/// The interface optimizers should implement.
|
||||
pub trait Optimizer: Sized {
|
||||
type Config: Sized;
|
||||
|
||||
fn new(vars: Vec<Var>, config: Self::Config) -> Result<Self>;
|
||||
|
||||
fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()>;
|
||||
|
||||
fn learning_rate(&self) -> f64;
|
||||
|
||||
fn set_learning_rate(&mut self, lr: f64);
|
||||
|
||||
fn empty(config: Self::Config) -> Result<Self> {
|
||||
Self::new(vec![], config)
|
||||
}
|
||||
|
||||
fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
|
||||
let grads = loss.backward()?;
|
||||
self.step(&grads)
|
||||
}
|
||||
|
||||
fn from_slice(vars: &[&Var], config: Self::Config) -> Result<Self> {
|
||||
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
|
||||
Self::new(vars, config)
|
||||
}
|
||||
}
|
||||
|
||||
/// Optimizer for Stochastic Gradient Descent.
|
||||
///
|
||||
/// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
|
||||
@ -10,42 +37,21 @@ pub struct SGD {
|
||||
learning_rate: f64,
|
||||
}
|
||||
|
||||
impl SGD {
|
||||
pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
|
||||
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
|
||||
Self {
|
||||
impl Optimizer for SGD {
|
||||
type Config = f64;
|
||||
|
||||
fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
||||
Ok(Self {
|
||||
vars,
|
||||
learning_rate,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
|
||||
Self {
|
||||
vars,
|
||||
learning_rate,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn empty(learning_rate: f64) -> Self {
|
||||
Self {
|
||||
vars: vec![],
|
||||
learning_rate,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Vec<Var> {
|
||||
self.vars
|
||||
}
|
||||
|
||||
pub fn learning_rate(&self) -> f64 {
|
||||
fn learning_rate(&self) -> f64 {
|
||||
self.learning_rate
|
||||
}
|
||||
|
||||
pub fn push(&mut self, var: &Var) {
|
||||
self.vars.push(var.clone())
|
||||
}
|
||||
|
||||
pub fn step(&self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||
fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||
for var in self.vars.iter() {
|
||||
if let Some(grad) = grads.get(var) {
|
||||
var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
|
||||
@ -54,13 +60,18 @@ impl SGD {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
|
||||
let grads = loss.backward()?;
|
||||
self.step(&grads)
|
||||
fn set_learning_rate(&mut self, lr: f64) {
|
||||
self.learning_rate = lr
|
||||
}
|
||||
}
|
||||
|
||||
impl SGD {
|
||||
pub fn into_inner(self) -> Vec<Var> {
|
||||
self.vars
|
||||
}
|
||||
|
||||
pub fn set_learning_rate(&mut self, lr: f64) {
|
||||
self.learning_rate = lr
|
||||
pub fn push(&mut self, var: &Var) {
|
||||
self.vars.push(var.clone())
|
||||
}
|
||||
}
|
||||
|
||||
@ -99,8 +110,10 @@ pub struct AdamW {
|
||||
params: ParamsAdamW,
|
||||
}
|
||||
|
||||
impl AdamW {
|
||||
pub fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
|
||||
impl Optimizer for AdamW {
|
||||
type Config = ParamsAdamW;
|
||||
|
||||
fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
|
||||
let vars = vars
|
||||
.into_iter()
|
||||
.map(|var| {
|
||||
@ -123,19 +136,15 @@ impl AdamW {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
||||
let params = ParamsAdamW {
|
||||
lr: learning_rate,
|
||||
..ParamsAdamW::default()
|
||||
};
|
||||
Self::new(vars, params)
|
||||
fn learning_rate(&self) -> f64 {
|
||||
self.params.lr
|
||||
}
|
||||
|
||||
pub fn set_learning_rate(&mut self, lr: f64) {
|
||||
fn set_learning_rate(&mut self, lr: f64) {
|
||||
self.params.lr = lr
|
||||
}
|
||||
|
||||
pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||
fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||
self.step_t += 1;
|
||||
let lr = self.params.lr;
|
||||
let lambda = self.params.weight_decay;
|
||||
@ -166,9 +175,14 @@ impl AdamW {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
|
||||
let grads = loss.backward()?;
|
||||
self.step(&grads)
|
||||
impl AdamW {
|
||||
pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
||||
let params = ParamsAdamW {
|
||||
lr: learning_rate,
|
||||
..ParamsAdamW::default()
|
||||
};
|
||||
Self::new(vars, params)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user