Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Improve the mnist training example. (#276)
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
@@ -116,21 +116,48 @@ impl Device {
         }
     }
 
+    pub(crate) fn rand_uniform_f64(
+        &self,
+        lo: f64,
+        up: f64,
+        shape: &Shape,
+        dtype: DType,
+    ) -> Result<Storage> {
+        match self {
+            Device::Cpu => {
+                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
+                Ok(Storage::Cpu(storage))
+            }
+            Device::Cuda(device) => {
+                let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                Ok(Storage::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn rand_uniform<T: crate::FloatDType>(
         &self,
         lo: T,
         up: T,
         shape: &Shape,
     ) -> Result<Storage> {
-        let lo = lo.to_f64();
-        let up = up.to_f64();
+        self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
+    }
+
+    pub(crate) fn rand_normal_f64(
+        &self,
+        mean: f64,
+        std: f64,
+        shape: &Shape,
+        dtype: DType,
+    ) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
+                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
+                let storage = device.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cuda(storage))
             }
         }
     }

@@ -142,18 +169,7 @@ impl Device {
         std: T,
         shape: &Shape,
     ) -> Result<Storage> {
-        let mean = mean.to_f64();
-        let std = std.to_f64();
-        match self {
-            Device::Cpu => {
-                let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
-                Ok(Storage::Cpu(storage))
-            }
-            Device::Cuda(device) => {
-                let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
-                Ok(Storage::Cuda(storage))
-            }
-        }
+        self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
     }
 
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
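The point of the new `_f64` entry points: the dtype-generic `rand_uniform`/`rand_normal` now convert their parameters once and funnel into a single non-generic body, so callers that only know the bounds at runtime (like the `Init` enum introduced further down) can sample without being generic over the float type. A self-contained sketch of the pattern, using hypothetical stand-in types rather than candle's real ones:

// Minimal sketch of the delegation pattern (hypothetical stand-ins):
// the generic front door converts to f64 and calls one concrete body.
trait FloatDType {
    fn to_f64(self) -> f64;
}
impl FloatDType for f32 {
    fn to_f64(self) -> f64 {
        self as f64
    }
}
impl FloatDType for f64 {
    fn to_f64(self) -> f64 {
        self
    }
}

struct Sampler;
impl Sampler {
    // One concrete implementation shared by every float type.
    fn rand_uniform_f64(&self, lo: f64, up: f64) -> f64 {
        0.5 * (lo + up) // placeholder for actual sampling
    }
    // Thin generic wrapper, mirroring `Device::rand_uniform` above.
    fn rand_uniform<T: FloatDType>(&self, lo: T, up: T) -> f64 {
        self.rand_uniform_f64(lo.to_f64(), up.to_f64())
    }
}

fn main() {
    let s = Sampler;
    assert_eq!(s.rand_uniform(0f32, 1f32), 0.5);
}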
@@ -245,6 +245,20 @@ impl Tensor {
         Ok(from_storage(storage, s, none, is_variable))
     }
 
+    pub(crate) fn rand_f64_impl<S: Into<Shape>>(
+        lo: f64,
+        up: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
+    }
+
     /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
     pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
         lo: T,

@@ -268,6 +282,20 @@ impl Tensor {
         Ok(from_storage(storage, s, none, is_variable))
     }
 
+    pub(crate) fn randn_f64_impl<S: Into<Shape>>(
+        mean: f64,
+        std: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
+    }
+
     /// Creates a new tensor initialized with values sampled from a normal distribution with the
     /// specified `mean` and standard deviation `std`.
     pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
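For reference, a minimal usage sketch of the public constructors that these `_impl` helpers back, assuming a CPU-only build of candle:

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // Uniform samples in [-1, 1) with shape (2, 3), dtype f32.
    let u = Tensor::rand(-1f32, 1f32, (2, 3), &dev)?;
    // Normal samples with mean 0 and standard deviation 1.
    let n = Tensor::randn(0f32, 1f32, (2, 3), &dev)?;
    assert_eq!(u.dims(), &[2, 3]);
    assert_eq!(n.dims(), &[2, 3]);
    Ok(())
}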
@@ -1448,6 +1476,16 @@ impl Tensor {
         }
     }
 
+    /// Create a variable based on the values currently stored in a tensor. The storage is always
+    /// copied.
+    pub(crate) fn make_var(&self) -> Result<Tensor> {
+        let shape = self.shape().clone();
+        let mut storage = self.device().zeros(&shape, self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), true))
+    }
+
     // TODO: Do we want to allow target shape using -1 on some dimensions?
     /// Reshape returns a tensor with the target shape provided that the number of elements of the
     /// original tensor is the same.
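A short sketch of why the copy matters: `make_var` allocates fresh zeroed storage and performs a strided copy, so even a non-contiguous view becomes an owned, contiguous variable. This assumes the `Var::from_tensor` wrapper added in the next hunk:

use candle::{Device, Tensor, Var};

fn main() -> candle::Result<()> {
    // A transposed view is not contiguous; `make_var` handles this by
    // copying through `copy_strided_src` into fresh storage.
    let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?.t()?;
    let v = Var::from_tensor(&t)?; // goes through `make_var` internally
    assert_eq!(v.to_vec2::<f32>()?, vec![vec![1f32, 3.], vec![2., 4.]]);
    Ok(())
}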
@@ -34,6 +34,33 @@ impl Var {
         Ok(Self(inner))
     }
 
+    pub fn from_tensor(t: &Tensor) -> Result<Self> {
+        let inner = t.make_var()?;
+        Ok(Self(inner))
+    }
+
+    pub fn rand_f64<S: Into<Shape>>(
+        lo: f64,
+        up: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn randn_f64<S: Into<Shape>>(
+        mean: f64,
+        std: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
     pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
         lo: T,
         up: T,
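A usage sketch for the new `f64`-parameterized constructor: the bounds stay `f64` even though the variable itself is created with a narrower dtype, which is exactly what `Init::var` below relies on:

use candle::{DType, Device, Var};

fn main() -> candle::Result<()> {
    // f64 bounds, f32 variable; signature taken from the hunk above.
    let w = Var::rand_f64(-0.1, 0.1, (10, 5), DType::F32, &Device::Cpu)?;
    assert_eq!(w.dims(), &[10, 5]);
    assert_eq!(w.dtype(), DType::F32);
    Ok(())
}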
@@ -2,8 +2,10 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+use clap::{Parser, ValueEnum};
+
 use candle::{DType, Device, Result, Shape, Tensor, Var, D};
-use candle_nn::{loss, ops, Linear};
+use candle_nn::{loss, ops, Init, Linear};
 use std::sync::{Arc, Mutex};
 
 const IMAGE_DIM: usize = 784;
@@ -44,7 +46,7 @@ impl VarStore {
         }
     }
 
-    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
+    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
         let shape = shape.into();
         let path = if self.path.is_empty() {
             tensor_name.to_string()

@@ -59,8 +61,7 @@ impl VarStore {
             }
             return Ok(tensor.as_tensor().clone());
         }
-        // TODO: Proper initialization using the `Init` enum.
-        let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
+        let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
         let tensor = var.as_tensor().clone();
         tensor_data.tensors.insert(path, var);
         Ok(tensor)
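With the `init` parameter threaded through, a cache miss in `get` now delegates variable creation to `Init::var` (implemented in the `init` module hunk below) instead of always zero-initializing. A minimal sketch of that call in isolation:

use candle::{DType, Device};
use candle_nn::Init;

fn main() -> candle::Result<()> {
    // What `get` does on a miss, for a constant initializer.
    let v = Init::Const(0.5).var((2, 3), DType::F32, &Device::Cpu)?;
    assert_eq!(v.to_vec2::<f32>()?, vec![vec![0.5f32; 3]; 2]);
    Ok(())
}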
@@ -77,21 +78,36 @@ impl VarStore {
     }
 }
 
-fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
-    let ws = vs.get((dim2, dim1), "weight")?;
-    let bs = vs.get(dim2, "bias")?;
+fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
+    let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
     Ok(Linear::new(ws, Some(bs)))
 }
 
-#[allow(unused)]
+fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
+    let bound = 1. / (in_dim as f64).sqrt();
+    let init_bs = Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get(out_dim, "bias", init_bs)?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+trait Model: Sized {
+    fn new(vs: VarStore) -> Result<Self>;
+    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+}
+
 struct LinearModel {
     linear: Linear,
 }
 
-#[allow(unused)]
-impl LinearModel {
+impl Model for LinearModel {
     fn new(vs: VarStore) -> Result<Self> {
-        let linear = linear(IMAGE_DIM, LABELS, vs)?;
+        let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
         Ok(Self { linear })
     }
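The new default `linear` pairs Kaiming-normal weights with a uniform bias of bound 1/sqrt(in_dim), which mirrors the bias default of PyTorch's `nn.Linear`. A quick numeric check of the bounds for this example's layer sizes:

fn main() {
    // Bias bound for the 784-input layer (28 * 28 pixels).
    let bound = 1. / (784f64).sqrt();
    assert!((bound - 1. / 28.).abs() < 1e-12); // 1/sqrt(784) = 1/28 ≈ 0.0357
    // For the 100-input layer the bias is uniform in ±0.1.
    assert!((1. / (100f64).sqrt() - 0.1).abs() < 1e-12);
}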
@@ -100,14 +116,12 @@ impl LinearModel {
     }
 }
 
-#[allow(unused)]
 struct Mlp {
     ln1: Linear,
     ln2: Linear,
 }
 
-#[allow(unused)]
-impl Mlp {
+impl Model for Mlp {
     fn new(vs: VarStore) -> Result<Self> {
         let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
         let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@@ -121,26 +135,22 @@ impl Mlp {
     }
 }
 
-pub fn main() -> anyhow::Result<()> {
+fn training_loop<M: Model>(
+    m: candle_nn::vision::Dataset,
+    learning_rate: f64,
+) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
-
-    // Load the dataset
-    let m = candle_nn::vision::mnist::load_dir("data")?;
-    println!("train-images: {:?}", m.train_images.shape());
-    println!("train-labels: {:?}", m.train_labels.shape());
-    println!("test-images: {:?}", m.test_images.shape());
-    println!("test-labels: {:?}", m.test_labels.shape());
     let train_labels = m.train_labels;
     let train_images = m.train_images;
     let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
 
     let vs = VarStore::new(DType::F32, dev);
-    let model = LinearModel::new(vs.clone())?;
-    // let model = Mlp::new(vs)?;
+    let model = M::new(vs.clone())?;
 
     let all_vars = vs.all_vars();
     let all_vars = all_vars.iter().collect::<Vec<_>>();
-    let sgd = candle_nn::SGD::new(&all_vars, 1.0);
+    let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
     let test_images = m.test_images;
     let test_labels = m.test_labels.to_dtype(DType::U32)?;
     for epoch in 1..200 {
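The body of the epoch loop falls outside this hunk. As a hedged sketch only, one step plausibly combines the `loss` and `ops` helpers imported at the top of the example with the SGD optimizer set up above; this is a fragment that assumes the surrounding bindings, not part of the diff:

// Hedged sketch of one epoch step (fragment; assumes `model`, `sgd`,
// `train_images`, and `train_labels` from the surrounding function):
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_labels)?;
sgd.backward_step(&loss)?;
// ...then evaluate accuracy on `test_images` against `test_labels`...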
@@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
     }
     Ok(())
 }
+
+#[derive(ValueEnum, Clone)]
+enum WhichModel {
+    Linear,
+    Mlp,
+}
+
+#[derive(Parser)]
+struct Args {
+    #[clap(value_enum, default_value_t = WhichModel::Linear)]
+    model: WhichModel,
+
+    #[arg(long)]
+    learning_rate: Option<f64>,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    // Load the dataset
+    let m = candle_nn::vision::mnist::load_dir("data")?;
+    println!("train-images: {:?}", m.train_images.shape());
+    println!("train-labels: {:?}", m.train_labels.shape());
+    println!("test-images: {:?}", m.test_images.shape());
+    println!("test-labels: {:?}", m.test_labels.shape());
+
+    match args.model {
+        WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
+        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
+    }
+}
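With the clap wiring above, the model and learning rate become selectable from the command line. Assuming the example is registered under the name mnist-training (the file path is not shown in this diff), an invocation would look like:

cargo run --example mnist-training --release -- mlp --learning-rate 0.05

`ValueEnum` derives lower-cased value names, so the accepted choices are `linear` and `mlp`; when `--learning-rate` is omitted, the defaults are 1.0 for the linear model and 0.01 for the MLP.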
@@ -1,7 +1,7 @@
 //! Variable initialization.
 // This is based on:
 // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
-use candle::Shape;
+use candle::{DType, Device, Result, Shape, Tensor, Var};
 
 /// Number of features as input or output of a layer.
 /// In Kaiming initialization, choosing `FanIn` preserves
@@ -91,11 +91,11 @@ pub enum Init {
         fan: FanInOut,
         non_linearity: NonLinearity,
     },
-
-    /// Orthogonal initialization
-    Orthogonal { gain: f64 },
 }
 
+pub const ZERO: Init = Init::Const(0.);
+pub const ONE: Init = Init::Const(1.);
+
 pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
     dist: NormalOrUniform::Uniform,
     fan: FanInOut::FanIn,
|
|||||||
fan: FanInOut::FanIn,
|
fan: FanInOut::FanIn,
|
||||||
non_linearity: NonLinearity::ReLU,
|
non_linearity: NonLinearity::ReLU,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
impl Init {
|
||||||
|
/// Creates a new tensor with the specified shape, device, and initialization.
|
||||||
|
pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
|
||||||
|
match self {
|
||||||
|
Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
|
||||||
|
Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
|
||||||
|
Self::Const(cst) => {
|
||||||
|
Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
|
||||||
|
}
|
||||||
|
Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
|
||||||
|
Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
|
||||||
|
Self::Kaiming {
|
||||||
|
dist,
|
||||||
|
fan,
|
||||||
|
non_linearity,
|
||||||
|
} => {
|
||||||
|
let s = s.into();
|
||||||
|
let fan = fan.for_shape(&s);
|
||||||
|
let gain = non_linearity.gain();
|
||||||
|
let std = gain / (fan as f64).sqrt();
|
||||||
|
match dist {
|
||||||
|
NormalOrUniform::Uniform => {
|
||||||
|
let bound = 3f64.sqrt() * std;
|
||||||
|
Var::rand_f64(-bound, bound, s, dtype, device)
|
||||||
|
}
|
||||||
|
NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
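For the Kaiming branch, the code computes `std = gain / sqrt(fan)` and, in the uniform case, a bound of `sqrt(3) * std`: a uniform distribution on ±b has variance b²/3, so that bound makes its standard deviation equal to `std`. Worked numbers for the example's 784-input layer, using the ReLU gain of sqrt(2) from the PyTorch scheme this file is based on:

fn main() {
    // ReLU gain is sqrt(2); fan_in of the 784 -> 100 layer is 784.
    let gain = 2f64.sqrt();
    let std = gain / (784f64).sqrt(); // ≈ 0.0505
    let bound = 3f64.sqrt() * std; // ≈ 0.0875, uniform in ±bound
    assert!((std - 0.0505).abs() < 1e-3);
    assert!((bound - 0.0875).abs() < 1e-3);
}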
@@ -15,6 +15,7 @@ pub mod vision;
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
 pub use embedding::Embedding;
+pub use init::Init;
 pub use layer_norm::LayerNorm;
 pub use linear::Linear;
 pub use optim::SGD;