mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Sketch the tensor initialization module. (#134)
candle-nn/src/init.rs (new file, 109 lines)

@@ -0,0 +1,109 @@
//! Variable initialization.
// This is based on:
// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
use candle::Shape;

/// Number of features as input or output of a layer.
/// In Kaiming initialization, choosing `FanIn` preserves
/// the magnitude of the variance of the weights in the
/// forward pass, choosing `FanOut` preserves this
/// magnitude in the backward pass.
#[derive(Debug, Copy, Clone)]
pub enum FanInOut {
    FanIn,
    FanOut,
}

impl FanInOut {
    /// Compute the fan-in or fan-out value for a weight tensor of
    /// the specified dimensions.
    /// <https://github.com/pytorch/pytorch/blob/dbeacf11820e336e803bb719b7aaaf2125ae4d9c/torch/nn/init.py#L284>
    pub fn for_shape(&self, shape: &Shape) -> usize {
        let dims = shape.dims();
        let receptive_field_size: usize = dims.iter().skip(2).product();
        match &self {
            FanInOut::FanIn => {
                if dims.len() < 2 {
                    1
                } else {
                    dims[1] * receptive_field_size
                }
            }
            FanInOut::FanOut => {
                if dims.is_empty() {
                    1
                } else {
                    dims[0] * receptive_field_size
                }
            }
        }
    }
}
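For instance, for a conv1d weight of shape `(out_channels, in_channels, kernel_size)`, the receptive field size is the product of the trailing dimensions, so fan-in multiplies `in_channels` by it and fan-out multiplies `out_channels` by it. A minimal sketch of a check (the tuple-to-`Shape` conversion is assumed to be available in `candle`):

    // Hypothetical check on a (16, 8, 3) conv weight: receptive field = 3,
    // so fan_in = 8 * 3 = 24 and fan_out = 16 * 3 = 48.
    let shape = candle::Shape::from((16usize, 8usize, 3usize));
    assert_eq!(FanInOut::FanIn.for_shape(&shape), 24);
    assert_eq!(FanInOut::FanOut.for_shape(&shape), 48);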
#[derive(Debug, Copy, Clone)]
pub enum NormalOrUniform {
    Normal,
    Uniform,
}

/// The non-linear function that follows this layer. ReLU is the
/// recommended value.
#[derive(Debug, Copy, Clone)]
pub enum NonLinearity {
    ReLU,
    Linear,
    Sigmoid,
    Tanh,
    SELU,
    ExplicitGain(f64),
}

impl NonLinearity {
    // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#L67
    pub fn gain(&self) -> f64 {
        match *self {
            NonLinearity::ReLU => 2f64.sqrt(),
            NonLinearity::Tanh => 5. / 3.,
            NonLinearity::Linear | NonLinearity::Sigmoid => 1.,
            NonLinearity::SELU => 0.75,
            NonLinearity::ExplicitGain(g) => g,
        }
    }
}
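These gain values follow the PyTorch table referenced above; as a quick sanity check of the sketch:

    // ReLU gain is sqrt(2) ≈ 1.4142, tanh gain is 5/3 ≈ 1.6667.
    assert_eq!(NonLinearity::ReLU.gain(), 2f64.sqrt());
    assert_eq!(NonLinearity::Tanh.gain(), 5. / 3.);
    assert_eq!(NonLinearity::ExplicitGain(0.5).gain(), 0.5);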
/// Variable initializations.
#[derive(Debug, Copy, Clone)]
pub enum Init {
    /// Constant value.
    Const(f64),

    /// Random normal with some mean and standard deviation.
    Randn { mean: f64, stdev: f64 },

    /// Uniform initialization between some lower and upper bounds.
    Uniform { lo: f64, up: f64 },

    /// Kaiming initialization.
    /// See "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification"
    /// He, K. et al. (2015). The `dist` field selects a uniform or normal distribution.
    Kaiming {
        dist: NormalOrUniform,
        fan: FanInOut,
        non_linearity: NonLinearity,
    },

    /// Orthogonal initialization.
    Orthogonal { gain: f64 },
}

pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
    dist: NormalOrUniform::Uniform,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};

pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
    dist: NormalOrUniform::Normal,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};
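The sampling itself is not part of this commit; `Init` only describes the initialization. As a rough sketch of how the pieces above would combine for `Init::Kaiming`, following the PyTorch convention std = gain / sqrt(fan) with a uniform bound of sqrt(3) * std (the helper name below is hypothetical, not part of this file):

    // Hypothetical helper: compute the sampling parameter of a Kaiming init
    // for a given weight shape (std for Normal, half-width bound for Uniform).
    fn kaiming_param(dist: NormalOrUniform, fan: FanInOut, nl: NonLinearity, shape: &Shape) -> f64 {
        let fan = fan.for_shape(shape) as f64;
        let std = nl.gain() / fan.sqrt();
        match dist {
            NormalOrUniform::Normal => std,
            NormalOrUniform::Uniform => 3f64.sqrt() * std,
        }
    }

    // E.g. DEFAULT_KAIMING_UNIFORM on a (256, 128) linear weight:
    // fan_in = 128, gain = sqrt(2), bound = sqrt(3 * 2 / 128) ≈ 0.2165.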
candle-nn/src/lib.rs

@@ -1,11 +1,12 @@
 // For now this crate shares its error type with candle-core. We may introduce some separate
 // error type if needed or add some specialized cases on the candle-core side.
-mod activation;
-mod conv;
-mod embedding;
-mod layer_norm;
-mod linear;
-mod var_builder;
+pub mod activation;
+pub mod conv;
+pub mod embedding;
+pub mod init;
+pub mod layer_norm;
+pub mod linear;
+pub mod var_builder;

 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};