Set the learning rate. (#687)

This commit is contained in:
Laurent Mazare
2023-08-31 09:03:40 +02:00
committed by GitHub
parent 8e84d8a59b
commit d210c71d77

View File

@ -58,6 +58,10 @@ impl SGD {
let grads = loss.backward()?;
self.step(&grads)
}
pub fn set_learning_rate(&mut self, lr: f64) {
self.learning_rate = lr
}
}
#[derive(Clone, Debug)]
@ -127,6 +131,10 @@ impl AdamW {
Self::new(vars, params)
}
pub fn set_learning_rate(&mut self, lr: f64) {
self.params.lr = lr
}
pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
self.step_t += 1;
let lr = self.params.lr;