Add the SGD optimizer (#160)

* Add the nn::optim and some conversion traits.

* Add the backward_step function for SGD.

* Get the SGD optimizer to work and add a test.

* Make the test slightly simpler.
Authored by Laurent Mazare on 2023-07-13 19:05:44 +01:00, committed by GitHub
parent 5ee3c95582
commit ded93a1169
6 changed files with 168 additions and 4 deletions

candle-nn/src/optim.rs (new file, +47)

@@ -0,0 +1,47 @@
//! Various optimization algorithms.
use candle::{Result, Tensor, Var};

#[derive(Debug)]
pub struct SGD {
    vars: Vec<Var>,
    learning_rate: f64,
}

impl SGD {
    pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
        Self {
            vars,
            learning_rate,
        }
    }

    pub fn empty(learning_rate: f64) -> Self {
        Self {
            vars: vec![],
            learning_rate,
        }
    }

    pub fn into_inner(self) -> Vec<Var> {
        self.vars
    }

    pub fn learning_rate(&self) -> f64 {
        self.learning_rate
    }

    pub fn push(&mut self, var: &Var) {
        self.vars.push(var.clone())
    }

    pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
        let grads = loss.backward()?;
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
                var.set(&var.sub(&(grad * self.learning_rate)?)?)?
            }
        }
        Ok(())
    }
}
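
For reference, backward_step computes the gradients of the loss with respect to every registered Var and applies the plain SGD update var <- var - learning_rate * grad in place. Below is a minimal usage sketch in the spirit of the test mentioned in the commit message (not part of this diff); it assumes SGD is re-exported at the candle_nn crate root, otherwise use candle_nn::optim::SGD.

use candle::{Device, Result, Var};
use candle_nn::SGD;

fn main() -> Result<()> {
    // A single trainable scalar, starting at 0.
    let x = Var::new(0f32, &Device::Cpu)?;
    let sgd = SGD::new(&[&x], 0.1);
    for _step in 0..100 {
        // loss = (x - 4.2)^2, minimized at x = 4.2.
        let diff = (x.as_tensor() - 4.2)?;
        let loss = (&diff * &diff)?;
        // Compute gradients and update every registered Var in place.
        sgd.backward_step(&loss)?;
    }
    // After 100 steps x should be close to 4.2.
    println!("x = {}", x.to_scalar::<f32>()?);
    Ok(())
}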