mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Set the learning rate. (#687)
This commit is contained in:
@ -58,6 +58,10 @@ impl SGD {
|
|||||||
let grads = loss.backward()?;
|
let grads = loss.backward()?;
|
||||||
self.step(&grads)
|
self.step(&grads)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_learning_rate(&mut self, lr: f64) {
|
||||||
|
self.learning_rate = lr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
@ -127,6 +131,10 @@ impl AdamW {
|
|||||||
Self::new(vars, params)
|
Self::new(vars, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_learning_rate(&mut self, lr: f64) {
|
||||||
|
self.params.lr = lr
|
||||||
|
}
|
||||||
|
|
||||||
pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
|
pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||||
self.step_t += 1;
|
self.step_t += 1;
|
||||||
let lr = self.params.lr;
|
let lr = self.params.lr;
|
||||||
|
Reference in New Issue
Block a user