Add the optimizer trait. (#702)

Author: Laurent Mazare
Date: 2023-09-01 13:55:39 +02:00
Committed by: GitHub
Parent: f2d476ca65
Commit: 7529531056
6 changed files with 69 additions and 54 deletions
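
The hunks below only show call sites, but two API changes are visible: construction is now fallible (`SGD::new(...)?`) and stepping requires a mutable binding. From that, the new `Optimizer` trait presumably looks roughly like the sketch below; only `new` and the `&mut self` requirement are confirmed by the hunks, the rest (the `Config` associated type, `backward_step`) is an assumption about the trait's shape.

use candle::{Result, Tensor, Var};

// Sketch of the trait this commit introduces; not a verbatim copy.
pub trait Optimizer: Sized {
    // Hyperparameters for the concrete optimizer (assumed associated type).
    type Config;

    // Construction can fail, hence the `?` at the call sites below.
    fn new(vars: Vec<Var>, config: Self::Config) -> Result<Self>;

    // Takes `&mut self`, which is why the bindings become `let mut sgd`.
    fn backward_step(&mut self, loss: &Tensor) -> Result<()>;
}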


@@ -8,12 +8,12 @@ use candle::test_utils::{to_vec0_round, to_vec2_round};
 use anyhow::Result;
 use candle::{Device, Tensor, Var};
-use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD};
+use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};

 #[test]
 fn sgd_optim() -> Result<()> {
     let x = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(vec![x.clone()], 0.1);
+    let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
     let xt = x.as_tensor();
     for _step in 0..100 {
         let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
@@ -59,7 +59,7 @@ fn sgd_linear_regression() -> Result<()> {
     // Now use backprop to run a linear regression between samples and get the coefficients back.
     let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
     let b = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
+    let mut sgd = SGD::new(vec![w.clone(), b.clone()], 0.004)?;
     let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
     for _step in 0..1000 {
         let ys = lin.forward(&sample_xs)?;
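
For reference, here is a self-contained version of the first test under the new API. The hunk is truncated before the end of the loop body, so the `backward_step` call is an assumption about how the loss is applied; everything else is taken directly from the diff.

use anyhow::Result;
use candle::{Device, Var};
use candle_nn::{Optimizer, SGD};

fn main() -> Result<()> {
    let x = Var::new(0f32, &Device::Cpu)?;
    // `SGD::new` is now fallible and the optimizer must be mutable.
    let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
    let xt = x.as_tensor();
    for _step in 0..100 {
        // Minimize (x - 4.2)^2 by gradient descent.
        let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
        // Assumed trait method: computes gradients and applies the update.
        sgd.backward_step(&loss)?;
    }
    // After 100 steps, x should be close to 4.2.
    println!("x = {}", x.as_tensor().to_scalar::<f32>()?);
    Ok(())
}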