Improve the mnist training example. (#276)

* Improve the mnist training example.

* Add some initialization routine that can be used for nn.

* Proper initialization in the mnist example.
This commit is contained in:
Laurent Mazare
2023-07-29 16:28:22 +01:00
committed by GitHub
parent bedcef64dc
commit 16c33383eb
6 changed files with 198 additions and 44 deletions

View File

@ -116,21 +116,48 @@ impl Device {
} }
} }
pub(crate) fn rand_uniform_f64(
&self,
lo: f64,
up: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
pub(crate) fn rand_uniform<T: crate::FloatDType>( pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self, &self,
lo: T, lo: T,
up: T, up: T,
shape: &Shape, shape: &Shape,
) -> Result<Storage> { ) -> Result<Storage> {
let lo = lo.to_f64(); self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
let up = up.to_f64(); }
pub(crate) fn rand_normal_f64(
&self,
mean: f64,
std: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self { match self {
Device::Cpu => { Device::Cpu => {
let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?; let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cpu(storage)) Ok(Storage::Cpu(storage))
} }
Device::Cuda(device) => { Device::Cuda(device) => {
let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?; let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage)) Ok(Storage::Cuda(storage))
} }
} }
@ -142,18 +169,7 @@ impl Device {
std: T, std: T,
shape: &Shape, shape: &Shape,
) -> Result<Storage> { ) -> Result<Storage> {
let mean = mean.to_f64(); self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
let std = std.to_f64();
match self {
Device::Cpu => {
let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
} }
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> { pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {

View File

@ -245,6 +245,20 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable)) Ok(from_storage(storage, s, none, is_variable))
} }
pub(crate) fn rand_f64_impl<S: Into<Shape>>(
lo: f64,
up: f64,
s: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`. /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
pub fn rand<S: Into<Shape>, T: crate::FloatDType>( pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T, lo: T,
@ -268,6 +282,20 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable)) Ok(from_storage(storage, s, none, is_variable))
} }
pub(crate) fn randn_f64_impl<S: Into<Shape>>(
mean: f64,
std: f64,
s: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled from a normal distribution with the /// Creates a new tensor initialized with values sampled from a normal distribution with the
/// specified `mean` and standard deviation `std`. /// specified `mean` and standard deviation `std`.
pub fn randn<S: Into<Shape>, T: crate::FloatDType>( pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
@ -1448,6 +1476,16 @@ impl Tensor {
} }
} }
/// Create a variable based on the values currently stored in a tensor. The storage is always
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {
let shape = self.shape().clone();
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}
// TODO: Do we want to allow target shape using -1 on some dimensions? // TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the /// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same. /// original tensor is the same.

View File

@ -34,6 +34,33 @@ impl Var {
Ok(Self(inner)) Ok(Self(inner))
} }
pub fn from_tensor(t: &Tensor) -> Result<Self> {
let inner = t.make_var()?;
Ok(Self(inner))
}
pub fn rand_f64<S: Into<Shape>>(
lo: f64,
up: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
Ok(Self(inner))
}
pub fn randn_f64<S: Into<Shape>>(
mean: f64,
std: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
Ok(Self(inner))
}
pub fn rand<S: Into<Shape>, T: crate::FloatDType>( pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T, lo: T,
up: T, up: T,

View File

@ -2,8 +2,10 @@
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Result, Shape, Tensor, Var, D}; use candle::{DType, Device, Result, Shape, Tensor, Var, D};
use candle_nn::{loss, ops, Linear}; use candle_nn::{loss, ops, Init, Linear};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
const IMAGE_DIM: usize = 784; const IMAGE_DIM: usize = 784;
@ -44,7 +46,7 @@ impl VarStore {
} }
} }
fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> { fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
let shape = shape.into(); let shape = shape.into();
let path = if self.path.is_empty() { let path = if self.path.is_empty() {
tensor_name.to_string() tensor_name.to_string()
@ -59,8 +61,7 @@ impl VarStore {
} }
return Ok(tensor.as_tensor().clone()); return Ok(tensor.as_tensor().clone());
} }
// TODO: Proper initialization using the `Init` enum. let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
let tensor = var.as_tensor().clone(); let tensor = var.as_tensor().clone();
tensor_data.tensors.insert(path, var); tensor_data.tensors.insert(path, var);
Ok(tensor) Ok(tensor)
@ -77,21 +78,36 @@ impl VarStore {
} }
} }
fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> { fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let ws = vs.get((dim2, dim1), "weight")?; let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get(dim2, "bias")?; let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs))) Ok(Linear::new(ws, Some(bs)))
} }
#[allow(unused)] fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vs.get(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}
trait Model: Sized {
fn new(vs: VarStore) -> Result<Self>;
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
struct LinearModel { struct LinearModel {
linear: Linear, linear: Linear,
} }
#[allow(unused)] impl Model for LinearModel {
impl LinearModel {
fn new(vs: VarStore) -> Result<Self> { fn new(vs: VarStore) -> Result<Self> {
let linear = linear(IMAGE_DIM, LABELS, vs)?; let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
Ok(Self { linear }) Ok(Self { linear })
} }
@ -100,14 +116,12 @@ impl LinearModel {
} }
} }
#[allow(unused)]
struct Mlp { struct Mlp {
ln1: Linear, ln1: Linear,
ln2: Linear, ln2: Linear,
} }
#[allow(unused)] impl Model for Mlp {
impl Mlp {
fn new(vs: VarStore) -> Result<Self> { fn new(vs: VarStore) -> Result<Self> {
let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?; let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
let ln2 = linear(100, LABELS, vs.pp("ln2"))?; let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@ -121,26 +135,22 @@ impl Mlp {
} }
} }
pub fn main() -> anyhow::Result<()> { fn training_loop<M: Model>(
m: candle_nn::vision::Dataset,
learning_rate: f64,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?; let dev = candle::Device::cuda_if_available(0)?;
// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());
let train_labels = m.train_labels; let train_labels = m.train_labels;
let train_images = m.train_images; let train_images = m.train_images;
let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?; let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
let vs = VarStore::new(DType::F32, dev); let vs = VarStore::new(DType::F32, dev);
let model = LinearModel::new(vs.clone())?; let model = M::new(vs.clone())?;
// let model = Mlp::new(vs)?;
let all_vars = vs.all_vars(); let all_vars = vs.all_vars();
let all_vars = all_vars.iter().collect::<Vec<_>>(); let all_vars = all_vars.iter().collect::<Vec<_>>();
let sgd = candle_nn::SGD::new(&all_vars, 1.0); let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
let test_images = m.test_images; let test_images = m.test_images;
let test_labels = m.test_labels.to_dtype(DType::U32)?; let test_labels = m.test_labels.to_dtype(DType::U32)?;
for epoch in 1..200 { for epoch in 1..200 {
@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
} }
Ok(()) Ok(())
} }
#[derive(ValueEnum, Clone)]
enum WhichModel {
Linear,
Mlp,
}
#[derive(Parser)]
struct Args {
#[clap(value_enum, default_value_t = WhichModel::Linear)]
model: WhichModel,
#[arg(long)]
learning_rate: Option<f64>,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());
match args.model {
WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
}
}

View File

@ -1,7 +1,7 @@
//! Variable initialization. //! Variable initialization.
// This is based on: // This is based on:
// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py# // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
use candle::Shape; use candle::{DType, Device, Result, Shape, Tensor, Var};
/// Number of features as input or output of a layer. /// Number of features as input or output of a layer.
/// In Kaiming initialization, choosing `FanIn` preserves /// In Kaiming initialization, choosing `FanIn` preserves
@ -91,11 +91,11 @@ pub enum Init {
fan: FanInOut, fan: FanInOut,
non_linearity: NonLinearity, non_linearity: NonLinearity,
}, },
/// Orthogonal initialization
Orthogonal { gain: f64 },
} }
pub const ZERO: Init = Init::Const(0.);
pub const ONE: Init = Init::Const(1.);
pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming { pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
dist: NormalOrUniform::Uniform, dist: NormalOrUniform::Uniform,
fan: FanInOut::FanIn, fan: FanInOut::FanIn,
@ -107,3 +107,35 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
fan: FanInOut::FanIn, fan: FanInOut::FanIn,
non_linearity: NonLinearity::ReLU, non_linearity: NonLinearity::ReLU,
}; };
impl Init {
/// Creates a new tensor with the specified shape, device, and initialization.
pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
match self {
Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
Self::Const(cst) => {
Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
}
Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
Self::Kaiming {
dist,
fan,
non_linearity,
} => {
let s = s.into();
let fan = fan.for_shape(&s);
let gain = non_linearity.gain();
let std = gain / (fan as f64).sqrt();
match dist {
NormalOrUniform::Uniform => {
let bound = 3f64.sqrt() * std;
Var::rand_f64(-bound, bound, s, dtype, device)
}
NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
}
}
}
}
}

View File

@ -15,6 +15,7 @@ pub mod vision;
pub use activation::Activation; pub use activation::Activation;
pub use conv::{Conv1d, Conv1dConfig}; pub use conv::{Conv1d, Conv1dConfig};
pub use embedding::Embedding; pub use embedding::Embedding;
pub use init::Init;
pub use layer_norm::LayerNorm; pub use layer_norm::LayerNorm;
pub use linear::Linear; pub use linear::Linear;
pub use optim::SGD; pub use optim::SGD;