From d210c71d77a6044c2a42c2e75487b6180e957158 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 31 Aug 2023 09:03:40 +0200 Subject: [PATCH] Set the learning rate. (#687) --- candle-nn/src/optim.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs index 39f7b34e..b5ac9dba 100644 --- a/candle-nn/src/optim.rs +++ b/candle-nn/src/optim.rs @@ -58,6 +58,10 @@ impl SGD { let grads = loss.backward()?; self.step(&grads) } + + pub fn set_learning_rate(&mut self, lr: f64) { + self.learning_rate = lr + } } #[derive(Clone, Debug)] @@ -127,6 +131,10 @@ impl AdamW { Self::new(vars, params) } + pub fn set_learning_rate(&mut self, lr: f64) { + self.params.lr = lr + } + pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> { self.step_t += 1; let lr = self.params.lr;