Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 03:54:56 +00:00)
Add the SGD optimizer (#160)
* Add the nn::optim and some conversion traits.
* Add the backward_step function for SGD.
* Get the SGD optimizer to work and add a test.
* Make the test slightly simpler.
candle-nn/src/optim.rs (new file, 47 lines)
@@ -0,0 +1,47 @@
//! Various optimization algorithms.
use candle::{Result, Tensor, Var};

#[derive(Debug)]
pub struct SGD {
    vars: Vec<Var>,
    learning_rate: f64,
}

impl SGD {
    pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
        Self {
            vars,
            learning_rate,
        }
    }

    pub fn empty(learning_rate: f64) -> Self {
        Self {
            vars: vec![],
            learning_rate,
        }
    }

    pub fn into_inner(self) -> Vec<Var> {
        self.vars
    }

    pub fn learning_rate(&self) -> f64 {
        self.learning_rate
    }

    pub fn push(&mut self, var: &Var) {
        self.vars.push(var.clone())
    }

    pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
        let grads = loss.backward()?;
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
                var.set(&var.sub(&(grad * self.learning_rate)?)?)?
            }
        }
        Ok(())
    }
}
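
As a rough illustration of how this optimizer is meant to be driven (not part of the commit), the sketch below fits a small linear model by calling backward_step in a loop. The import path candle_nn::optim::SGD and the tensor helpers used here (broadcast_mul, broadcast_add, sum_all, to_scalar) are assumptions about the surrounding candle API rather than anything shown in this diff; the crate is referred to as candle to match the use statement above.

// Hypothetical usage sketch, not part of this commit: fit y = 3x + 1 by
// gradient descent with the SGD optimizer defined above. Import paths and
// tensor helpers are assumed from the candle crates.
use candle::{Device, Result, Tensor, Var};
use candle_nn::optim::SGD; // path assumed; the crate may re-export SGD at its root

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Training data: y = 3x + 1 for x in 0..=4.
    let xs = Tensor::new(&[0f32, 1., 2., 3., 4.], &dev)?;
    let ys = Tensor::new(&[1f32, 4., 7., 10., 13.], &dev)?;

    // Trainable parameters; wrapping them in Var is what lets
    // loss.backward() produce gradients for them.
    let w = Var::new(0f32, &dev)?;
    let b = Var::new(0f32, &dev)?;
    let sgd = SGD::new(&[&w, &b], 0.01);

    for _step in 0..1000 {
        // Forward pass: w * x + b, then a sum-of-squares loss.
        let ys_hat = xs.broadcast_mul(&w)?.broadcast_add(&b)?;
        let loss = (ys_hat - &ys)?.sqr()?.sum_all()?;
        // backward_step computes the gradients and applies
        // var <- var - learning_rate * grad to every registered Var.
        sgd.backward_step(&loss)?;
    }
    // Should print values close to w = 3, b = 1.
    println!("w = {}, b = {}", w.to_scalar::<f32>()?, b.to_scalar::<f32>()?);
    Ok(())
}

The important point for the diff above is the last line of the loop: the vars are never touched directly; backward_step walks every Var handed to SGD::new (or added later via push) and updates it in place through Var::set.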