Files
candle/candle-nn/tests/optim.rs
Laurent Mazare ded93a1169 Add the SGD optimizer (#160)
* Add the nn::optim and some conversion traits.

* Add the backward_step function for SGD.

* Get the SGD optimizer to work and add a test.

* Make the test slightly simpler.
2023-07-13 19:05:44 +01:00

20 lines
444 B
Rust

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::Result;
use candle::{Device, Var};
use candle_nn::SGD;
/// Checks that the `SGD` optimizer can minimise a simple scalar loss.
///
/// A scalar variable `x` starts at `0.0` and the loss `(x - 4.2)^2` is minimised
/// for 100 steps with a learning rate of `0.1`. Afterwards `x` is expected to sit
/// at the loss minimiser, `4.2`.
#[test]
fn sgd_optim() -> Result<()> {
    // Minimiser of the quadratic loss below.
    const TARGET: f64 = 4.2;

    let x = Var::new(0f32, &Device::Cpu)?;
    let sgd = SGD::new(&[&x], 0.1);
    let xt = x.as_tensor();
    for _step in 0..100 {
        // Build `x - TARGET` once and square it, instead of constructing two
        // identical subtraction nodes.
        let diff = (xt - TARGET)?;
        let loss = (&diff * &diff)?;
        sgd.backward_step(&loss)?;
    }

    // The exact f32 value reached after 100 chained updates depends on rounding
    // in the backend (e.g. when built with the `mkl` feature). So compare
    // against the analytic minimiser with a tolerance rather than asserting a
    // bit-exact literal such as 4.199999.
    let got = x.to_scalar::<f32>()?;
    assert!(
        (f64::from(got) - TARGET).abs() < 1e-4,
        "expected x to converge to {}, got {}",
        TARGET,
        got
    );
    Ok(())
}