mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add the optimizer trait. (#702)
This commit is contained in:
@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
|
||||
use rand::prelude::*;
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{loss, ops, Conv2d, Linear, Module, VarBuilder, VarMap};
|
||||
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
|
||||
|
||||
const IMAGE_DIM: usize = 784;
|
||||
const LABELS: usize = 10;
|
||||
@ -190,7 +190,7 @@ fn training_loop<M: Model>(
|
||||
varmap.load(load)?
|
||||
}
|
||||
|
||||
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
for epoch in 1..args.epochs {
|
||||
|
Reference in New Issue
Block a user