Improve the mnist training example. (#276)

* Improve the mnist training example.

* Add some initialization routine that can be used for nn.

* Proper initialization in the mnist example.
This commit is contained in:
Laurent Mazare
2023-07-29 16:28:22 +01:00
committed by GitHub
parent bedcef64dc
commit 16c33383eb
6 changed files with 198 additions and 44 deletions

View File

@ -116,21 +116,48 @@ impl Device {
}
}
pub(crate) fn rand_uniform_f64(
&self,
lo: f64,
up: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self,
lo: T,
up: T,
shape: &Shape,
) -> Result<Storage> {
let lo = lo.to_f64();
let up = up.to_f64();
self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
}
pub(crate) fn rand_normal_f64(
&self,
mean: f64,
std: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
@ -142,18 +169,7 @@ impl Device {
std: T,
shape: &Shape,
) -> Result<Storage> {
let mean = mean.to_f64();
let std = std.to_f64();
match self {
Device::Cpu => {
let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
}
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {

View File

@ -245,6 +245,20 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable))
}
pub(crate) fn rand_f64_impl<S: Into<Shape>>(
lo: f64,
up: f64,
s: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
@ -268,6 +282,20 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable))
}
pub(crate) fn randn_f64_impl<S: Into<Shape>>(
mean: f64,
std: f64,
s: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled from a normal distribution with the
/// specified `mean` and standard deviation `std`.
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
@ -1448,6 +1476,16 @@ impl Tensor {
}
}
/// Create a variable based on the values currently stored in a tensor. The storage is always
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {
let shape = self.shape().clone();
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}
// TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.

View File

@ -34,6 +34,33 @@ impl Var {
Ok(Self(inner))
}
pub fn from_tensor(t: &Tensor) -> Result<Self> {
let inner = t.make_var()?;
Ok(Self(inner))
}
pub fn rand_f64<S: Into<Shape>>(
lo: f64,
up: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
Ok(Self(inner))
}
pub fn randn_f64<S: Into<Shape>>(
mean: f64,
std: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
Ok(Self(inner))
}
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
up: T,

View File

@ -2,8 +2,10 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Result, Shape, Tensor, Var, D};
use candle_nn::{loss, ops, Linear};
use candle_nn::{loss, ops, Init, Linear};
use std::sync::{Arc, Mutex};
const IMAGE_DIM: usize = 784;
@ -44,7 +46,7 @@ impl VarStore {
}
}
fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
let shape = shape.into();
let path = if self.path.is_empty() {
tensor_name.to_string()
@ -59,8 +61,7 @@ impl VarStore {
}
return Ok(tensor.as_tensor().clone());
}
// TODO: Proper initialization using the `Init` enum.
let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
let tensor = var.as_tensor().clone();
tensor_data.tensors.insert(path, var);
Ok(tensor)
@ -77,21 +78,36 @@ impl VarStore {
}
}
fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
let ws = vs.get((dim2, dim1), "weight")?;
let bs = vs.get(dim2, "bias")?;
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs)))
}
#[allow(unused)]
fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vs.get(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}
trait Model: Sized {
fn new(vs: VarStore) -> Result<Self>;
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
struct LinearModel {
linear: Linear,
}
#[allow(unused)]
impl LinearModel {
impl Model for LinearModel {
fn new(vs: VarStore) -> Result<Self> {
let linear = linear(IMAGE_DIM, LABELS, vs)?;
let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
Ok(Self { linear })
}
@ -100,14 +116,12 @@ impl LinearModel {
}
}
#[allow(unused)]
struct Mlp {
ln1: Linear,
ln2: Linear,
}
#[allow(unused)]
impl Mlp {
impl Model for Mlp {
fn new(vs: VarStore) -> Result<Self> {
let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@ -121,26 +135,22 @@ impl Mlp {
}
}
pub fn main() -> anyhow::Result<()> {
fn training_loop<M: Model>(
m: candle_nn::vision::Dataset,
learning_rate: f64,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());
let train_labels = m.train_labels;
let train_images = m.train_images;
let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
let vs = VarStore::new(DType::F32, dev);
let model = LinearModel::new(vs.clone())?;
// let model = Mlp::new(vs)?;
let model = M::new(vs.clone())?;
let all_vars = vs.all_vars();
let all_vars = all_vars.iter().collect::<Vec<_>>();
let sgd = candle_nn::SGD::new(&all_vars, 1.0);
let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
let test_images = m.test_images;
let test_labels = m.test_labels.to_dtype(DType::U32)?;
for epoch in 1..200 {
@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
}
Ok(())
}
#[derive(ValueEnum, Clone)]
enum WhichModel {
Linear,
Mlp,
}
#[derive(Parser)]
struct Args {
#[clap(value_enum, default_value_t = WhichModel::Linear)]
model: WhichModel,
#[arg(long)]
learning_rate: Option<f64>,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());
match args.model {
WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
}
}

View File

@ -1,7 +1,7 @@
//! Variable initialization.
// This is based on:
// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
use candle::Shape;
use candle::{DType, Device, Result, Shape, Tensor, Var};
/// Number of features as input or output of a layer.
/// In Kaiming initialization, choosing `FanIn` preserves
@ -91,11 +91,11 @@ pub enum Init {
fan: FanInOut,
non_linearity: NonLinearity,
},
/// Orthogonal initialization
Orthogonal { gain: f64 },
}
pub const ZERO: Init = Init::Const(0.);
pub const ONE: Init = Init::Const(1.);
pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
dist: NormalOrUniform::Uniform,
fan: FanInOut::FanIn,
@ -107,3 +107,35 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
fan: FanInOut::FanIn,
non_linearity: NonLinearity::ReLU,
};
impl Init {
/// Creates a new tensor with the specified shape, device, and initialization.
pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
match self {
Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
Self::Const(cst) => {
Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
}
Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
Self::Kaiming {
dist,
fan,
non_linearity,
} => {
let s = s.into();
let fan = fan.for_shape(&s);
let gain = non_linearity.gain();
let std = gain / (fan as f64).sqrt();
match dist {
NormalOrUniform::Uniform => {
let bound = 3f64.sqrt() * std;
Var::rand_f64(-bound, bound, s, dtype, device)
}
NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
}
}
}
}
}

View File

@ -15,6 +15,7 @@ pub mod vision;
pub use activation::Activation;
pub use conv::{Conv1d, Conv1dConfig};
pub use embedding::Embedding;
pub use init::Init;
pub use layer_norm::LayerNorm;
pub use linear::Linear;
pub use optim::SGD;