From 16c33383eb2beda515962b219728209b9edb2946 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 29 Jul 2023 16:28:22 +0100
Subject: [PATCH] Improve the mnist training example. (#276)

* Improve the mnist training example.

* Add some initialization routine that can be used for nn.

* Proper initialization in the mnist example.
---
 candle-core/src/device.rs                | 48 ++++++----
 candle-core/src/tensor.rs                | 38 ++++++++
 candle-core/src/variable.rs              | 27 ++++++
 .../examples/simple-training/main.rs     | 88 ++++++++++++++-----
 candle-nn/src/init.rs                    | 40 ++++++++-
 candle-nn/src/lib.rs                     |  1 +
 6 files changed, 198 insertions(+), 44 deletions(-)

diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 89df8f84..563d892b 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -116,21 +116,48 @@ impl Device {
         }
     }
 
+    pub(crate) fn rand_uniform_f64(
+        &self,
+        lo: f64,
+        up: f64,
+        shape: &Shape,
+        dtype: DType,
+    ) -> Result<Storage> {
+        match self {
+            Device::Cpu => {
+                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
+                Ok(Storage::Cpu(storage))
+            }
+            Device::Cuda(device) => {
+                let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                Ok(Storage::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn rand_uniform<T: crate::FloatDType>(
         &self,
         lo: T,
         up: T,
         shape: &Shape,
     ) -> Result<Storage> {
-        let lo = lo.to_f64();
-        let up = up.to_f64();
+        self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
+    }
+
+    pub(crate) fn rand_normal_f64(
+        &self,
+        mean: f64,
+        std: f64,
+        shape: &Shape,
+        dtype: DType,
+    ) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
+                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
+                let storage = device.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cuda(storage))
             }
         }
@@ -142,18 +169,7 @@ impl Device {
         std: T,
         shape: &Shape,
     ) -> Result<Storage> {
-        let mean = mean.to_f64();
-        let std = std.to_f64();
-        match self {
-            Device::Cpu => {
-                let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
-                Ok(Storage::Cpu(storage))
-            }
-            Device::Cuda(device) => {
-                let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
-                Ok(Storage::Cuda(storage))
-            }
-        }
+        self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
     }
 
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 8ae92c2e..060e8792 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -245,6 +245,20 @@ impl Tensor {
         Ok(from_storage(storage, s, none, is_variable))
     }
 
+    pub(crate) fn rand_f64_impl<S: Into<Shape>>(
+        lo: f64,
+        up: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
+    }
+
     /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
     pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
         lo: T,
@@ -268,6 +282,20 @@ impl Tensor {
         Ok(from_storage(storage, s, none, is_variable))
     }
 
+    pub(crate) fn randn_f64_impl<S: Into<Shape>>(
+        mean: f64,
+        std: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
+    }
+
     /// Creates a new tensor initialized with values sampled from a normal distribution with the
     /// specified `mean` and standard deviation `std`.
     pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
@@ -1448,6 +1476,16 @@ impl Tensor {
         }
     }
 
+    /// Create a variable based on the values currently stored in a tensor. The storage is always
+    /// copied.
+    pub(crate) fn make_var(&self) -> Result<Tensor> {
+        let shape = self.shape().clone();
+        let mut storage = self.device().zeros(&shape, self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), true))
+    }
+
     // TODO: Do we want to allow target shape using -1 on some dimensions?
     /// Reshape returns a tensor with the target shape provided that the number of elements of the
     /// original tensor is the same.
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index 0cefee11..61800bf3 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -34,6 +34,33 @@ impl Var {
         Ok(Self(inner))
     }
 
+    pub fn from_tensor(t: &Tensor) -> Result<Self> {
+        let inner = t.make_var()?;
+        Ok(Self(inner))
+    }
+
+    pub fn rand_f64<S: Into<Shape>>(
+        lo: f64,
+        up: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn randn_f64<S: Into<Shape>>(
+        mean: f64,
+        std: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
     pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
         lo: T,
         up: T,
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index 35b938e8..f15aa60c 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -2,8 +2,10 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+use clap::{Parser, ValueEnum};
+
 use candle::{DType, Device, Result, Shape, Tensor, Var, D};
-use candle_nn::{loss, ops, Linear};
+use candle_nn::{loss, ops, Init, Linear};
 use std::sync::{Arc, Mutex};
 
 const IMAGE_DIM: usize = 784;
@@ -44,7 +46,7 @@ impl VarStore {
         }
     }
 
-    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
+    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
         let shape = shape.into();
         let path = if self.path.is_empty() {
             tensor_name.to_string()
@@ -59,8 +61,7 @@ impl VarStore {
             }
             return Ok(tensor.as_tensor().clone());
         }
-        // TODO: Proper initialization using the `Init` enum.
-        let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
+        let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
         let tensor = var.as_tensor().clone();
         tensor_data.tensors.insert(path, var);
         Ok(tensor)
@@ -77,21 +78,36 @@ impl VarStore {
     }
 }
 
-fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
-    let ws = vs.get((dim2, dim1), "weight")?;
-    let bs = vs.get(dim2, "bias")?;
+fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
+    let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
     Ok(Linear::new(ws, Some(bs)))
 }
 
-#[allow(unused)]
+fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
+    let bound = 1. / (in_dim as f64).sqrt();
+    let init_bs = Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get(out_dim, "bias", init_bs)?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+trait Model: Sized {
+    fn new(vs: VarStore) -> Result<Self>;
+    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+}
+
 struct LinearModel {
     linear: Linear,
 }
 
-#[allow(unused)]
-impl LinearModel {
+impl Model for LinearModel {
     fn new(vs: VarStore) -> Result<Self> {
-        let linear = linear(IMAGE_DIM, LABELS, vs)?;
+        let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
         Ok(Self { linear })
     }
 
@@ -100,14 +116,12 @@ impl LinearModel {
     }
 }
 
-#[allow(unused)]
 struct Mlp {
     ln1: Linear,
     ln2: Linear,
 }
 
-#[allow(unused)]
-impl Mlp {
+impl Model for Mlp {
     fn new(vs: VarStore) -> Result<Self> {
         let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
         let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@@ -121,26 +135,22 @@ impl Mlp {
     }
 }
 
-pub fn main() -> anyhow::Result<()> {
+fn training_loop<M: Model>(
+    m: candle_nn::vision::Dataset,
+    learning_rate: f64,
+) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
 
-    // Load the dataset
-    let m = candle_nn::vision::mnist::load_dir("data")?;
-    println!("train-images: {:?}", m.train_images.shape());
-    println!("train-labels: {:?}", m.train_labels.shape());
-    println!("test-images: {:?}", m.test_images.shape());
-    println!("test-labels: {:?}", m.test_labels.shape());
 
     let train_labels = m.train_labels;
     let train_images = m.train_images;
     let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
 
     let vs = VarStore::new(DType::F32, dev);
-    let model = LinearModel::new(vs.clone())?;
-    // let model = Mlp::new(vs)?;
+    let model = M::new(vs.clone())?;
     let all_vars = vs.all_vars();
     let all_vars = all_vars.iter().collect::<Vec<_>>();
-    let sgd = candle_nn::SGD::new(&all_vars, 1.0);
+    let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
     let test_images = m.test_images;
     let test_labels = m.test_labels.to_dtype(DType::U32)?;
     for epoch in 1..200 {
@@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
     }
     Ok(())
 }
+
+#[derive(ValueEnum, Clone)]
+enum WhichModel {
+    Linear,
+    Mlp,
+}
+
+#[derive(Parser)]
+struct Args {
+    #[clap(value_enum, default_value_t = WhichModel::Linear)]
+    model: WhichModel,
+
+    #[arg(long)]
+    learning_rate: Option<f64>,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    // Load the dataset
+    let m = candle_nn::vision::mnist::load_dir("data")?;
+    println!("train-images: {:?}", m.train_images.shape());
+    println!("train-labels: {:?}", m.train_labels.shape());
+    println!("test-images: {:?}", m.test_images.shape());
+    println!("test-labels: {:?}", m.test_labels.shape());
+
+    match args.model {
+        WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
+        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
+    }
+}
diff --git a/candle-nn/src/init.rs b/candle-nn/src/init.rs
index 762f0ef1..25702d52 100644
--- a/candle-nn/src/init.rs
+++ b/candle-nn/src/init.rs
@@ -1,7 +1,7 @@
 //! Variable initialization.
 // This is based on:
 // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
-use candle::Shape;
+use candle::{DType, Device, Result, Shape, Tensor, Var};
 
 /// Number of features as input or output of a layer.
 /// In Kaiming initialization, choosing `FanIn` preserves
@@ -91,11 +91,11 @@ pub enum Init {
         fan: FanInOut,
         non_linearity: NonLinearity,
     },
-
-    /// Orthogonal initialization
-    Orthogonal { gain: f64 },
 }
 
+pub const ZERO: Init = Init::Const(0.);
+pub const ONE: Init = Init::Const(1.);
+
 pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
     dist: NormalOrUniform::Uniform,
     fan: FanInOut::FanIn,
@@ -107,3 +107,35 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
     fan: FanInOut::FanIn,
     non_linearity: NonLinearity::ReLU,
 };
+
+impl Init {
+    /// Creates a new tensor with the specified shape, device, and initialization.
+    pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
+        match self {
+            Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
+            Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
+            Self::Const(cst) => {
+                Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
+            }
+            Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
+            Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
+            Self::Kaiming {
+                dist,
+                fan,
+                non_linearity,
+            } => {
+                let s = s.into();
+                let fan = fan.for_shape(&s);
+                let gain = non_linearity.gain();
+                let std = gain / (fan as f64).sqrt();
+                match dist {
+                    NormalOrUniform::Uniform => {
+                        let bound = 3f64.sqrt() * std;
+                        Var::rand_f64(-bound, bound, s, dtype, device)
+                    }
+                    NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
+                }
+            }
+        }
+    }
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index db01b067..d0b62dbb 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -15,6 +15,7 @@ pub mod vision;
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
 pub use embedding::Embedding;
+pub use init::Init;
 pub use layer_norm::LayerNorm;
 pub use linear::Linear;
 pub use optim::SGD;
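
Usage note (not part of the patch): a minimal sketch of how the Init API added in candle-nn/src/init.rs above can be used on its own, outside the example's VarStore. It relies only on items visible in this diff (Init, Init::var, candle_nn::init::DEFAULT_KAIMING_NORMAL, Init::Uniform, Var); the linear_vars helper name is made up for illustration.

    // Sketch only: build freshly initialized Vars for a linear layer, mirroring
    // the `linear` helper from the mnist example above. `linear_vars` is a
    // hypothetical name, not part of candle.
    use candle::{DType, Device, Result, Var};
    use candle_nn::Init;

    fn linear_vars(in_dim: usize, out_dim: usize, dev: &Device) -> Result<(Var, Var)> {
        // Kaiming-normal weights (fan-in, ReLU gain), as in DEFAULT_KAIMING_NORMAL.
        let ws = candle_nn::init::DEFAULT_KAIMING_NORMAL.var((out_dim, in_dim), DType::F32, dev)?;
        // Uniform bias in [-1/sqrt(in_dim), 1/sqrt(in_dim)], as in the example's `linear`.
        let bound = 1. / (in_dim as f64).sqrt();
        let bs = Init::Uniform { lo: -bound, up: bound }.var(out_dim, DType::F32, dev)?;
        Ok((ws, bs))
    }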