Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Improve the mnist training example. (#276)
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
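The main addition is an `Init` argument threaded through the example's `VarStore::get`, so that fresh variables are created by a chosen initialization scheme instead of always zeros. As a minimal standalone sketch of that API, assuming only what the diff below shows (`Init::Uniform`, the `Init::var` method, and the `candle_nn::init` constants); the `init_linear_vars` helper is hypothetical:

use candle::{DType, Device, Result, Var};
use candle_nn::Init;

// Hypothetical helper: build weight and bias variables for a linear layer
// the way the new `linear` function in the diff does.
fn init_linear_vars(in_dim: usize, out_dim: usize, dev: &Device) -> Result<(Var, Var)> {
    // Kaiming-normal weights.
    let ws = candle_nn::init::DEFAULT_KAIMING_NORMAL.var((out_dim, in_dim), DType::F32, dev)?;
    // PyTorch-style bias: uniform over [-1/sqrt(in_dim), 1/sqrt(in_dim)].
    let bound = 1. / (in_dim as f64).sqrt();
    let bs = Init::Uniform { lo: -bound, up: bound }.var(out_dim, DType::F32, dev)?;
    Ok((ws, bs))
}

`linear_z` keeps zero initialization, so the linear classifier behaves exactly as the old `Var::zeros` path did, while the MLP's layers now get a proper Kaiming/uniform scheme.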
@@ -2,8 +2,10 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+use clap::{Parser, ValueEnum};
+
 use candle::{DType, Device, Result, Shape, Tensor, Var, D};
-use candle_nn::{loss, ops, Linear};
+use candle_nn::{loss, ops, Init, Linear};
 use std::sync::{Arc, Mutex};
 
 const IMAGE_DIM: usize = 784;
@@ -44,7 +46,7 @@ impl VarStore {
         }
     }
 
-    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
+    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
         let shape = shape.into();
         let path = if self.path.is_empty() {
             tensor_name.to_string()
@@ -59,8 +61,7 @@ impl VarStore {
             }
             return Ok(tensor.as_tensor().clone());
         }
-        // TODO: Proper initialization using the `Init` enum.
-        let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
+        let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
         let tensor = var.as_tensor().clone();
         tensor_data.tensors.insert(path, var);
         Ok(tensor)
@@ -77,21 +78,36 @@ impl VarStore {
     }
 }
 
-fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
-    let ws = vs.get((dim2, dim1), "weight")?;
-    let bs = vs.get(dim2, "bias")?;
+fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
+    let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
     Ok(Linear::new(ws, Some(bs)))
 }
 
+#[allow(unused)]
+fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
+    let bound = 1. / (in_dim as f64).sqrt();
+    let init_bs = Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get(out_dim, "bias", init_bs)?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+trait Model: Sized {
+    fn new(vs: VarStore) -> Result<Self>;
+    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+}
+
 struct LinearModel {
     linear: Linear,
 }
 
-#[allow(unused)]
-impl LinearModel {
+impl Model for LinearModel {
     fn new(vs: VarStore) -> Result<Self> {
-        let linear = linear(IMAGE_DIM, LABELS, vs)?;
+        let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
         Ok(Self { linear })
     }
 
@@ -100,14 +116,12 @@ impl LinearModel {
     }
 }
 
-#[allow(unused)]
 struct Mlp {
     ln1: Linear,
     ln2: Linear,
 }
 
-#[allow(unused)]
-impl Mlp {
+impl Model for Mlp {
     fn new(vs: VarStore) -> Result<Self> {
         let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
         let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@@ -121,26 +135,22 @@ impl Mlp {
     }
 }
 
-pub fn main() -> anyhow::Result<()> {
+fn training_loop<M: Model>(
+    m: candle_nn::vision::Dataset,
+    learning_rate: f64,
+) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
 
-    // Load the dataset
-    let m = candle_nn::vision::mnist::load_dir("data")?;
-    println!("train-images: {:?}", m.train_images.shape());
-    println!("train-labels: {:?}", m.train_labels.shape());
-    println!("test-images: {:?}", m.test_images.shape());
-    println!("test-labels: {:?}", m.test_labels.shape());
     let train_labels = m.train_labels;
     let train_images = m.train_images;
     let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
 
     let vs = VarStore::new(DType::F32, dev);
-    let model = LinearModel::new(vs.clone())?;
-    // let model = Mlp::new(vs)?;
+    let model = M::new(vs.clone())?;
 
     let all_vars = vs.all_vars();
     let all_vars = all_vars.iter().collect::<Vec<_>>();
-    let sgd = candle_nn::SGD::new(&all_vars, 1.0);
+    let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
     let test_images = m.test_images;
     let test_labels = m.test_labels.to_dtype(DType::U32)?;
     for epoch in 1..200 {
@@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
     }
     Ok(())
 }
+
+#[derive(ValueEnum, Clone)]
+enum WhichModel {
+    Linear,
+    Mlp,
+}
+
+#[derive(Parser)]
+struct Args {
+    #[clap(value_enum, default_value_t = WhichModel::Linear)]
+    model: WhichModel,
+
+    #[arg(long)]
+    learning_rate: Option<f64>,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    // Load the dataset
+    let m = candle_nn::vision::mnist::load_dir("data")?;
+    println!("train-images: {:?}", m.train_images.shape());
+    println!("train-labels: {:?}", m.train_labels.shape());
+    println!("test-images: {:?}", m.test_images.shape());
+    println!("test-labels: {:?}", m.test_labels.shape());
+
+    match args.model {
+        WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
+        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
+    }
+}
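With the `Model` trait in place, adding an architecture only requires `new` and `forward`; `training_loop` is generic over the model. A sketch of what a deeper MLP could look like if added to this example file (the `DeepMlp` name, its layer sizes, and `Tensor::relu` as the activation are assumptions for illustration; `linear`, `VarStore::pp`, `IMAGE_DIM` and `LABELS` come from the code above):

// Hypothetical third model; relies on the `Model` trait, the `linear`
// helper and the constants already defined in this file.
struct DeepMlp {
    ln1: Linear,
    ln2: Linear,
    ln3: Linear,
}

impl Model for DeepMlp {
    fn new(vs: VarStore) -> Result<Self> {
        let ln1 = linear(IMAGE_DIM, 256, vs.pp("ln1"))?;
        let ln2 = linear(256, 64, vs.pp("ln2"))?;
        let ln3 = linear(64, LABELS, vs.pp("ln3"))?;
        Ok(Self { ln1, ln2, ln3 })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.ln1.forward(xs)?.relu()?;
        let xs = self.ln2.forward(&xs)?.relu()?;
        self.ln3.forward(&xs)
    }
}

Wiring it up would mean one more `WhichModel` variant and a matching `training_loop::<DeepMlp>(m, ...)` arm in `main`. As defined above, the CLI picks the model positionally: running with no arguments trains the zero-initialized linear classifier at learning rate 1.0, while `mlp --learning-rate 0.01` selects the Kaiming-initialized MLP.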