Improve the mnist training example. (#276)

* Improve the mnist training example.

* Add some initialization routine that can be used for nn.

* Proper initialization in the mnist example.
This commit is contained in:
Laurent Mazare
2023-07-29 16:28:22 +01:00
committed by GitHub
parent bedcef64dc
commit 16c33383eb
6 changed files with 198 additions and 44 deletions

View File

@ -2,8 +2,10 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Result, Shape, Tensor, Var, D};
use candle_nn::{loss, ops, Linear};
use candle_nn::{loss, ops, Init, Linear};
use std::sync::{Arc, Mutex};
const IMAGE_DIM: usize = 784;
@ -44,7 +46,7 @@ impl VarStore {
}
}
fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
let shape = shape.into();
let path = if self.path.is_empty() {
tensor_name.to_string()
@ -59,8 +61,7 @@ impl VarStore {
}
return Ok(tensor.as_tensor().clone());
}
// TODO: Proper initialization using the `Init` enum.
let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
let tensor = var.as_tensor().clone();
tensor_data.tensors.insert(path, var);
Ok(tensor)
@ -77,21 +78,36 @@ impl VarStore {
}
}
fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
let ws = vs.get((dim2, dim1), "weight")?;
let bs = vs.get(dim2, "bias")?;
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs)))
}
#[allow(unused)]
fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vs.get(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}
trait Model: Sized {
fn new(vs: VarStore) -> Result<Self>;
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
struct LinearModel {
linear: Linear,
}
#[allow(unused)]
impl LinearModel {
impl Model for LinearModel {
fn new(vs: VarStore) -> Result<Self> {
let linear = linear(IMAGE_DIM, LABELS, vs)?;
let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
Ok(Self { linear })
}
@ -100,14 +116,12 @@ impl LinearModel {
}
}
#[allow(unused)]
struct Mlp {
ln1: Linear,
ln2: Linear,
}
#[allow(unused)]
impl Mlp {
impl Model for Mlp {
fn new(vs: VarStore) -> Result<Self> {
let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@ -121,26 +135,22 @@ impl Mlp {
}
}
pub fn main() -> anyhow::Result<()> {
fn training_loop<M: Model>(
m: candle_nn::vision::Dataset,
learning_rate: f64,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());
let train_labels = m.train_labels;
let train_images = m.train_images;
let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
let vs = VarStore::new(DType::F32, dev);
let model = LinearModel::new(vs.clone())?;
// let model = Mlp::new(vs)?;
let model = M::new(vs.clone())?;
let all_vars = vs.all_vars();
let all_vars = all_vars.iter().collect::<Vec<_>>();
let sgd = candle_nn::SGD::new(&all_vars, 1.0);
let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
let test_images = m.test_images;
let test_labels = m.test_labels.to_dtype(DType::U32)?;
for epoch in 1..200 {
@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
}
Ok(())
}
#[derive(ValueEnum, Clone)]
enum WhichModel {
Linear,
Mlp,
}
#[derive(Parser)]
struct Args {
#[clap(value_enum, default_value_t = WhichModel::Linear)]
model: WhichModel,
#[arg(long)]
learning_rate: Option<f64>,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());
match args.model {
WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
}
}