Sketch the tensor initialization module. (#134)

Author: Laurent Mazare
Date: 2023-07-11 07:41:46 +01:00
Committed by: GitHub
Parent: 0e9d3afd77
Commit: b31a3bbdcb

2 changed files with 116 additions and 6 deletions

candle-nn/src/init.rs (new file)

@@ -0,0 +1,109 @@
//! Variable initialization.
// This is based on:
// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
use candle::Shape;

/// Number of features as input or output of a layer.
/// In Kaiming initialization, choosing `FanIn` preserves
/// the magnitude of the variance of the weights in the
/// forward pass, choosing `FanOut` preserves this
/// magnitude in the backward pass.
#[derive(Debug, Copy, Clone)]
pub enum FanInOut {
    FanIn,
    FanOut,
}

impl FanInOut {
    /// Compute the fan-in or fan-out value for a weight tensor of
    /// the specified dimensions.
    /// <https://github.com/pytorch/pytorch/blob/dbeacf11820e336e803bb719b7aaaf2125ae4d9c/torch/nn/init.py#L284>
    pub fn for_shape(&self, shape: &Shape) -> usize {
        let dims = shape.dims();
        let receptive_field_size: usize = dims.iter().skip(2).product();
        match &self {
            FanInOut::FanIn => {
                if dims.len() < 2 {
                    1
                } else {
                    dims[1] * receptive_field_size
                }
            }
            FanInOut::FanOut => {
                if dims.is_empty() {
                    1
                } else {
                    dims[0] * receptive_field_size
                }
            }
        }
    }
}

#[derive(Debug, Copy, Clone)]
pub enum NormalOrUniform {
    Normal,
    Uniform,
}

/// The non-linear function that follows this layer. ReLU is the
/// recommended value.
#[derive(Debug, Copy, Clone)]
pub enum NonLinearity {
    ReLU,
    Linear,
    Sigmoid,
    Tanh,
    SELU,
    ExplicitGain(f64),
}

impl NonLinearity {
    // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#L67
    pub fn gain(&self) -> f64 {
        match *self {
            NonLinearity::ReLU => 2f64.sqrt(),
            NonLinearity::Tanh => 5. / 3.,
            NonLinearity::Linear | NonLinearity::Sigmoid => 1.,
            NonLinearity::SELU => 0.75,
            NonLinearity::ExplicitGain(g) => g,
        }
    }
}

/// Variable initializations.
#[derive(Debug, Copy, Clone)]
pub enum Init {
    /// Constant value.
    Const(f64),
    /// Random normal with some mean and standard deviation.
    Randn { mean: f64, stdev: f64 },
    /// Uniform initialization between some lower and upper bounds.
    Uniform { lo: f64, up: f64 },
    /// Kaiming uniform initialization.
    /// See "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification"
    /// He, K. et al. (2015). This uses a uniform distribution.
    Kaiming {
        dist: NormalOrUniform,
        fan: FanInOut,
        non_linearity: NonLinearity,
    },
    /// Orthogonal initialization
    Orthogonal { gain: f64 },
}

pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
    dist: NormalOrUniform::Uniform,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};

pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
    dist: NormalOrUniform::Normal,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};
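
The commit only sketches the data types; nothing here samples values yet. As a rough illustration of how the pieces fit together, a consumer could reproduce PyTorch's kaiming_uniform_ bound (std = gain / sqrt(fan), bound = sqrt(3) * std) from `FanInOut::for_shape` and `NonLinearity::gain`. This sketch is not part of the commit; it assumes the same `candle` crate alias used in the file above and that `Shape` can be built from a tuple as provided by candle-core.

// Sketch only: combines FanInOut::for_shape and NonLinearity::gain to compute
// the kaiming_uniform_ bound for a conv2d weight. Not part of this commit.
use candle::Shape;
use candle_nn::init::{FanInOut, Init, NonLinearity, NormalOrUniform};

fn kaiming_uniform_bound(shape: &Shape) -> f64 {
    // For a conv2d weight (out_c, in_c, kh, kw), fan-in = in_c * kh * kw.
    let fan = FanInOut::FanIn.for_shape(shape) as f64;
    // gain() returns sqrt(2) for ReLU.
    let gain = NonLinearity::ReLU.gain();
    // PyTorch: std = gain / sqrt(fan); for a uniform distribution, bound = sqrt(3) * std.
    3f64.sqrt() * gain / fan.sqrt()
}

fn main() {
    // A 3x3 convolution with 128 input and 64 output channels (assumes a
    // From<(usize, usize, usize, usize)> impl for Shape in candle-core).
    let shape = Shape::from((64, 128, 3, 3));
    println!("fan_in = {}", FanInOut::FanIn.for_shape(&shape)); // 128 * 3 * 3 = 1152
    println!("bound  = {}", kaiming_uniform_bound(&shape));

    // The same configuration expressed with the Init enum defined above.
    let _init = Init::Kaiming {
        dist: NormalOrUniform::Uniform,
        fan: FanInOut::FanIn,
        non_linearity: NonLinearity::ReLU,
    };
}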

candle-nn/src/lib.rs

@@ -1,11 +1,12 @@
 // For now this crate shares its error type with candle-core. We may introduce some separate
 // error type if needed or add some specialized cases on the candle-core side.
-mod activation;
-mod conv;
-mod embedding;
-mod layer_norm;
-mod linear;
-mod var_builder;
+pub mod activation;
+pub mod conv;
+pub mod embedding;
+pub mod init;
+pub mod layer_norm;
+pub mod linear;
+pub mod var_builder;
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
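
The lib.rs change makes the previously private modules public and wires in the new init module, so downstream crates can name the initialization defaults directly. A minimal, hypothetical consumer might look like this (the helper name is made up for illustration):

// Hypothetical downstream usage enabled by `pub mod init;`; not part of this commit.
use candle_nn::init::{Init, DEFAULT_KAIMING_UNIFORM};

// Pick an initialization scheme for a layer's weight.
fn weight_init(small_uniform: bool) -> Init {
    if small_uniform {
        Init::Uniform { lo: -0.1, up: 0.1 }
    } else {
        DEFAULT_KAIMING_UNIFORM
    }
}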