Add the optimizer trait. (#702)
@@ -8,12 +8,12 @@ use candle::test_utils::{to_vec0_round, to_vec2_round};
 use anyhow::Result;
 use candle::{Device, Tensor, Var};
-use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD};
+use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
 
 #[test]
 fn sgd_optim() -> Result<()> {
     let x = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(vec![x.clone()], 0.1);
+    let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
     let xt = x.as_tensor();
     for _step in 0..100 {
         let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
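For context, the updated sgd_optim test reads roughly as follows once the Optimizer trait is in scope. The hunk above truncates the loop body, so the backward_step call and the absence of a final assertion are assumptions about the surrounding test, not part of this diff:

use anyhow::Result;
use candle::{Device, Var};
use candle_nn::{Optimizer, SGD};

#[test]
fn sgd_optim() -> Result<()> {
    // A single scalar parameter, driven towards 4.2 by gradient descent.
    let x = Var::new(0f32, &Device::Cpu)?;
    // SGD::new is now fallible, and the binding must be `mut` because
    // applying an update goes through `&mut self` on the Optimizer trait.
    let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
    let xt = x.as_tensor();
    for _step in 0..100 {
        // Quadratic loss (x - 4.2)^2, minimized at x = 4.2.
        let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
        // Assumed: backward_step runs backprop and applies the SGD update.
        sgd.backward_step(&loss)?;
    }
    Ok(())
}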
@@ -59,7 +59,7 @@ fn sgd_linear_regression() -> Result<()> {
     // Now use backprop to run a linear regression between samples and get the coefficients back.
     let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
     let b = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
+    let mut sgd = SGD::new(vec![w.clone(), b.clone()], 0.004)?;
     let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
     for _step in 0..1000 {
         let ys = lin.forward(&sample_xs)?;
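The trait that motivates these changes is sketched below; the method names follow candle-nn's optim module, but treat the exact signatures as an approximation of what this commit introduces, not a verbatim copy. Construction returning Result explains the new `?` after SGD::new, and step taking `&mut self` explains the `mut` bindings:

use candle::{backprop::GradStore, Result, Tensor, Var};

pub trait Optimizer: Sized {
    // Per-optimizer hyper-parameters, e.g. ParamsAdamW for AdamW.
    type Config: Sized;

    // Fallible construction over the set of trainable variables.
    fn new(vars: Vec<Var>, config: Self::Config) -> Result<Self>;

    // Apply one update from an already computed gradient store.
    fn step(&mut self, grads: &GradStore) -> Result<()>;

    fn learning_rate(&self) -> f64;
    fn set_learning_rate(&mut self, lr: f64);

    // Convenience wrapper: backprop on the loss, then step.
    fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
        let grads = loss.backward()?;
        self.step(&grads)
    }
}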