Mirror of https://github.com/huggingface/candle.git, synced 2025-06-21 04:10:46 +00:00
Improve the mnist training example. (#276)
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
candle-nn/src/init.rs
@@ -1,7 +1,7 @@
 //! Variable initialization.
 // This is based on:
 // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
-use candle::Shape;
+use candle::{DType, Device, Result, Shape, Tensor, Var};
 
 /// Number of features as input or output of a layer.
 /// In Kaiming initialization, choosing `FanIn` preserves
@@ -91,11 +91,11 @@ pub enum Init {
         fan: FanInOut,
         non_linearity: NonLinearity,
     },
-
-    /// Orthogonal initialization
-    Orthogonal { gain: f64 },
 }
 
+pub const ZERO: Init = Init::Const(0.);
+pub const ONE: Init = Init::Const(1.);
+
 pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
     dist: NormalOrUniform::Uniform,
     fan: FanInOut::FanIn,
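This hunk removes the `Orthogonal` variant and introduces the `ZERO` and `ONE` shorthands. A minimal sketch of creating a zero-filled bias with them, using the `var` method added in the next hunk (the shape and device are illustrative, not from the commit):

    use candle::{DType, Device, Result};
    use candle_nn::init;

    fn main() -> Result<()> {
        // Zero-filled bias for a hypothetical 100-unit layer; `init::ZERO`
        // dispatches to `Var::zeros` inside `Init::var`.
        let bs = init::ZERO.var(100, DType::F32, &Device::Cpu)?;
        assert_eq!(bs.dims(), &[100]);
        Ok(())
    }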
@@ -107,3 +107,35 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
     fan: FanInOut::FanIn,
     non_linearity: NonLinearity::ReLU,
 };
+
+impl Init {
+    /// Creates a new tensor with the specified shape, device, and initialization.
+    pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
+        match self {
+            Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
+            Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
+            Self::Const(cst) => {
+                Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
+            }
+            Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
+            Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
+            Self::Kaiming {
+                dist,
+                fan,
+                non_linearity,
+            } => {
+                let s = s.into();
+                let fan = fan.for_shape(&s);
+                let gain = non_linearity.gain();
+                let std = gain / (fan as f64).sqrt();
+                match dist {
+                    NormalOrUniform::Uniform => {
+                        let bound = 3f64.sqrt() * std;
+                        Var::rand_f64(-bound, bound, s, dtype, device)
+                    }
+                    NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
+                }
+            }
+        }
+    }
+}
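The Kaiming arm follows the PyTorch recipe referenced at the top of the file: the target standard deviation is gain / sqrt(fan), and since a uniform distribution on [-b, b] has standard deviation b / sqrt(3), the bound is set to sqrt(3) * std so both branches yield the same variance. A minimal usage sketch, assuming the usual (out_features, in_features) weight layout so that `FanIn` resolves to 784 (shape and device are illustrative):

    use candle::{DType, Device, Result};
    use candle_nn::init::DEFAULT_KAIMING_UNIFORM;

    fn main() -> Result<()> {
        // Kaiming-uniform weight for a 784 -> 100 layer: with fan-in 784 and
        // the usual ReLU gain of sqrt(2), values are drawn uniformly from
        // [-b, b] with b = sqrt(3) * sqrt(2) / sqrt(784).
        let ws = DEFAULT_KAIMING_UNIFORM.var((100, 784), DType::F32, &Device::Cpu)?;
        assert_eq!(ws.dims(), &[100, 784]);
        Ok(())
    }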
|
candle-nn/src/lib.rs
@@ -15,6 +15,7 @@ pub mod vision;
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
 pub use embedding::Embedding;
+pub use init::Init;
 pub use layer_norm::LayerNorm;
 pub use linear::Linear;
 pub use optim::SGD;
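With the new re-export, `Init` is reachable from the crate root. A sketch of wiring initialized parameters into a `Linear` layer; the `Linear::new(weight, bias)` constructor, `Module` import, and 4 -> 2 shapes are assumptions based on current candle-nn, not part of this commit:

    use candle::{DType, Device, Module, Result, Tensor};
    use candle_nn::{init, Init, Linear};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Kaiming-uniform weight and constant-zero bias for a 4 -> 2 layer.
        let ws = init::DEFAULT_KAIMING_UNIFORM.var((2, 4), DType::F32, &dev)?;
        let bs = Init::Const(0.).var(2, DType::F32, &dev)?;
        let layer = Linear::new(ws.as_tensor().clone(), Some(bs.as_tensor().clone()));
        // Push a dummy batch of three inputs through the layer.
        let xs = Tensor::zeros((3, 4), DType::F32, &dev)?;
        let ys = layer.forward(&xs)?;
        assert_eq!(ys.dims(), &[3, 2]);
        Ok(())
    }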
|