mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00
Add the SGD optimizer (#160)
* Add the nn::optim and some conversion traits. * Add the backward_step function for SGD. * Get the SGD optimizer to work and add a test. * Make the test slighly simpler.
This commit is contained in:
19
candle-nn/tests/optim.rs
Normal file
19
candle-nn/tests/optim.rs
Normal file
@ -0,0 +1,19 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Var};
|
||||
use candle_nn::SGD;
|
||||
|
||||
#[test]
|
||||
fn sgd_optim() -> Result<()> {
|
||||
let x = Var::new(0f32, &Device::Cpu)?;
|
||||
let sgd = SGD::new(&[&x], 0.1);
|
||||
let xt = x.as_tensor();
|
||||
for _step in 0..100 {
|
||||
let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
|
||||
sgd.backward_step(&loss)?
|
||||
}
|
||||
assert_eq!(x.to_scalar::<f32>()?, 4.199999);
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user