Sketch the tensor initialization module. (#134)
candle-nn/src/init.rs | 109 (new file)
@@ -0,0 +1,109 @@
//! Variable initialization.
// This is based on:
// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
use candle::Shape;

/// Number of features as input or output of a layer.
/// In Kaiming initialization, choosing `FanIn` preserves
/// the magnitude of the variance of the weights in the
/// forward pass, choosing `FanOut` preserves this
/// magnitude in the backward pass.
#[derive(Debug, Copy, Clone)]
pub enum FanInOut {
    FanIn,
    FanOut,
}

impl FanInOut {
    /// Compute the fan-in or fan-out value for a weight tensor of
    /// the specified dimensions.
    /// <https://github.com/pytorch/pytorch/blob/dbeacf11820e336e803bb719b7aaaf2125ae4d9c/torch/nn/init.py#L284>
    pub fn for_shape(&self, shape: &Shape) -> usize {
        let dims = shape.dims();
        let receptive_field_size: usize = dims.iter().skip(2).product();
        match &self {
            FanInOut::FanIn => {
                if dims.len() < 2 {
                    1
                } else {
                    dims[1] * receptive_field_size
                }
            }
            FanInOut::FanOut => {
                if dims.is_empty() {
                    1
                } else {
                    dims[0] * receptive_field_size
                }
            }
        }
    }
}

#[derive(Debug, Copy, Clone)]
pub enum NormalOrUniform {
    Normal,
    Uniform,
}

/// The non-linear function that follows this layer. ReLU is the
/// recommended value.
#[derive(Debug, Copy, Clone)]
pub enum NonLinearity {
    ReLU,
    Linear,
    Sigmoid,
    Tanh,
    SELU,
    ExplicitGain(f64),
}

impl NonLinearity {
    // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#L67
    pub fn gain(&self) -> f64 {
        match *self {
            NonLinearity::ReLU => 2f64.sqrt(),
            NonLinearity::Tanh => 5. / 3.,
            NonLinearity::Linear | NonLinearity::Sigmoid => 1.,
            NonLinearity::SELU => 0.75,
            NonLinearity::ExplicitGain(g) => g,
        }
    }
}

/// Variable initializations.
#[derive(Debug, Copy, Clone)]
pub enum Init {
    /// Constant value.
    Const(f64),

    /// Random normal with some mean and standard deviation.
    Randn { mean: f64, stdev: f64 },

    /// Uniform initialization between some lower and upper bounds.
    Uniform { lo: f64, up: f64 },

    /// Kaiming initialization.
    /// See "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification",
    /// He, K. et al. (2015). The `dist` field selects a uniform or a normal distribution.
    Kaiming {
        dist: NormalOrUniform,
        fan: FanInOut,
        non_linearity: NonLinearity,
    },

    /// Orthogonal initialization
    Orthogonal { gain: f64 },
}

pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
    dist: NormalOrUniform::Uniform,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};

pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
    dist: NormalOrUniform::Normal,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};
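
A quick usage sketch of the two helpers above (not part of the commit; it assumes the crate is importable as `candle_nn` and that `candle::Shape` converts from a `Vec<usize>`):

use candle::Shape;
use candle_nn::init::{FanInOut, NonLinearity};

fn main() {
    // Conv2d weight layout: (out_channels, in_channels, k_h, k_w).
    let shape = Shape::from(vec![16usize, 8, 3, 3]);

    // receptive_field_size = 3 * 3 = 9, i.e. every dim past the first two.
    // FanIn = in_channels * 9 = 72, FanOut = out_channels * 9 = 144.
    assert_eq!(FanInOut::FanIn.for_shape(&shape), 72);
    assert_eq!(FanInOut::FanOut.for_shape(&shape), 144);

    // Gains follow the PyTorch table linked in the source:
    // sqrt(2) for ReLU, 5/3 for tanh, 1 for linear and sigmoid.
    assert_eq!(NonLinearity::Tanh.gain(), 5. / 3.);
    assert_eq!(NonLinearity::ReLU.gain(), 2f64.sqrt());
}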

candle-nn/src/lib.rs

@@ -1,11 +1,12 @@
 // For now this crate shares its error type with candle-core. We may introduce some separate
 // error type if needed or add some specialized cases on the candle-core side.
-mod activation;
-mod conv;
-mod embedding;
-mod layer_norm;
-mod linear;
-mod var_builder;
+pub mod activation;
+pub mod conv;
+pub mod embedding;
+pub mod init;
+pub mod layer_norm;
+pub mod linear;
+pub mod var_builder;
 
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
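
The commit only sketches the types; hooking them up to actual tensor creation comes later. The helper below is a hypothetical illustration (its name and the handling of each variant are not from the crate) of how an `Init` value could be reduced to sampler parameters, using the standard Kaiming formula std = gain / sqrt(fan) and, for the uniform case, bound = sqrt(3) * std:

use candle::Shape;
use candle_nn::init::{Init, NormalOrUniform, DEFAULT_KAIMING_UNIFORM};

// Hypothetical helper: reduce an `Init` spec to the two numbers a sampler needs,
// (mean, stdev) for a normal draw or (lo, up) for a uniform one.
fn sampling_params(init: Init, shape: &Shape) -> (NormalOrUniform, f64, f64) {
    match init {
        // A constant is a degenerate uniform on [v, v].
        Init::Const(v) => (NormalOrUniform::Uniform, v, v),
        Init::Randn { mean, stdev } => (NormalOrUniform::Normal, mean, stdev),
        Init::Uniform { lo, up } => (NormalOrUniform::Uniform, lo, up),
        Init::Kaiming { dist, fan, non_linearity } => {
            let fan = fan.for_shape(shape) as f64;
            let std = non_linearity.gain() / fan.sqrt();
            match dist {
                NormalOrUniform::Normal => (dist, 0., std),
                // sqrt(3) * std gives U(-bound, bound) the same variance as
                // the corresponding normal distribution.
                NormalOrUniform::Uniform => {
                    let bound = 3f64.sqrt() * std;
                    (dist, -bound, bound)
                }
            }
        }
        // Orthogonal init needs a QR decomposition of a random matrix rather
        // than per-element sampling, so it falls outside this simple signature.
        Init::Orthogonal { .. } => unimplemented!("orthogonal init"),
    }
}

fn main() {
    // Bounds for a 128x64 linear weight under the default Kaiming-uniform init:
    // fan_in = 64, std = sqrt(2) / 8, bound = sqrt(3) * std, roughly 0.3062.
    let shape = Shape::from(vec![128usize, 64]);
    let (_, lo, up) = sampling_params(DEFAULT_KAIMING_UNIFORM, &shape);
    println!("({lo:.4}, {up:.4})");
}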