From b31a3bbdcbf1a75bbb18cdc2aa0fbff2ab931351 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 11 Jul 2023 07:41:46 +0100
Subject: [PATCH] Sketch the tensor initialization module. (#134)

---
 candle-nn/src/init.rs | 109 ++++++++++++++++++++++++++++++++++++++++++
 candle-nn/src/lib.rs  |  13 ++---
 2 files changed, 116 insertions(+), 6 deletions(-)
 create mode 100644 candle-nn/src/init.rs

diff --git a/candle-nn/src/init.rs b/candle-nn/src/init.rs
new file mode 100644
index 00000000..762f0ef1
--- /dev/null
+++ b/candle-nn/src/init.rs
@@ -0,0 +1,109 @@
+//! Variable initialization.
+// This is based on:
+// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
+use candle::Shape;
+
+/// Number of features as input or output of a layer.
+/// In Kaiming initialization, choosing `FanIn` preserves
+/// the magnitude of the variance of the weights in the
+/// forward pass, choosing `FanOut` preserves this
+/// magnitude in the backward pass.
+#[derive(Debug, Copy, Clone)]
+pub enum FanInOut {
+    FanIn,
+    FanOut,
+}
+
+impl FanInOut {
+    /// Compute the fan-in or fan-out value for a weight tensor of
+    /// the specified dimensions.
+    ///
+    pub fn for_shape(&self, shape: &Shape) -> usize {
+        let dims = shape.dims();
+        let receptive_field_size: usize = dims.iter().skip(2).product();
+        match &self {
+            FanInOut::FanIn => {
+                if dims.len() < 2 {
+                    1
+                } else {
+                    dims[1] * receptive_field_size
+                }
+            }
+            FanInOut::FanOut => {
+                if dims.is_empty() {
+                    1
+                } else {
+                    dims[0] * receptive_field_size
+                }
+            }
+        }
+    }
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum NormalOrUniform {
+    Normal,
+    Uniform,
+}
+
+/// The non-linear function that follows this layer. ReLU is the
+/// recommended value.
+#[derive(Debug, Copy, Clone)]
+pub enum NonLinearity {
+    ReLU,
+    Linear,
+    Sigmoid,
+    Tanh,
+    SELU,
+    ExplicitGain(f64),
+}
+
+impl NonLinearity {
+    // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#L67
+    pub fn gain(&self) -> f64 {
+        match *self {
+            NonLinearity::ReLU => 2f64.sqrt(),
+            NonLinearity::Tanh => 5. / 3.,
+            NonLinearity::Linear | NonLinearity::Sigmoid => 1.,
+            NonLinearity::SELU => 0.75,
+            NonLinearity::ExplicitGain(g) => g,
+        }
+    }
+}
+
+/// Variable initializations.
+#[derive(Debug, Copy, Clone)]
+pub enum Init {
+    /// Constant value.
+    Const(f64),
+
+    /// Random normal with some mean and standard deviation.
+    Randn { mean: f64, stdev: f64 },
+
+    /// Uniform initialization between some lower and upper bounds.
+    Uniform { lo: f64, up: f64 },
+
+    /// Kaiming initialization.
+    /// See "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification"
+    /// He, K. et al. (2015). The distribution, normal or uniform, is selected via `dist`.
+    Kaiming {
+        dist: NormalOrUniform,
+        fan: FanInOut,
+        non_linearity: NonLinearity,
+    },
+
+    /// Orthogonal initialization.
+    Orthogonal { gain: f64 },
+}
+
+pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
+    dist: NormalOrUniform::Uniform,
+    fan: FanInOut::FanIn,
+    non_linearity: NonLinearity::ReLU,
+};
+
+pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
+    dist: NormalOrUniform::Normal,
+    fan: FanInOut::FanIn,
+    non_linearity: NonLinearity::ReLU,
+};
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index efda417b..bb168661 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -1,11 +1,12 @@
 // For now this crate shares its error type with candle-core. We may introduce some separate
 // error type if needed or add some specialized cases on the candle-core side.
-mod activation;
-mod conv;
-mod embedding;
-mod layer_norm;
-mod linear;
-mod var_builder;
+pub mod activation;
+pub mod conv;
+pub mod embedding;
+pub mod init;
+pub mod layer_norm;
+pub mod linear;
+pub mod var_builder;
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
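
A minimal usage sketch, not part of the patch: the fan and gain computations above combine into the usual Kaiming-uniform bound, gain * sqrt(3 / fan). The helper `kaiming_uniform_bound` below is hypothetical glue code; it assumes the crate is consumed as `candle_nn` via the `pub mod init;` export from the lib.rs hunk, that the core crate is imported as `candle` (as in the patch itself), and that `candle::Shape` can be built from a `(usize, usize)` tuple.

use candle::Shape;
use candle_nn::init::{FanInOut, NonLinearity};

/// Hypothetical helper: bound of the Kaiming-uniform distribution,
/// i.e. gain * sqrt(3 / fan_in), which is the combination selected by
/// DEFAULT_KAIMING_UNIFORM (uniform distribution, fan-in mode, ReLU gain).
fn kaiming_uniform_bound(shape: &Shape) -> f64 {
    // Fan-in: dims[1] times the receptive field size (product of dims[2..]).
    let fan_in = FanInOut::FanIn.for_shape(shape) as f64;
    // sqrt(2) for ReLU.
    let gain = NonLinearity::ReLU.gain();
    // std = gain / sqrt(fan_in); a uniform on [-b, b] has std b / sqrt(3),
    // hence b = sqrt(3) * gain / sqrt(fan_in).
    gain * (3.0 / fan_in).sqrt()
}

fn main() {
    // Weight of a linear layer mapping 784 inputs to 256 outputs
    // (assumes Shape: From<(usize, usize)> in the candle core crate).
    let shape = Shape::from((256usize, 784usize));
    println!("kaiming uniform bound: {:.6}", kaiming_uniform_bound(&shape));
}

Note that the patch itself only defines the configuration types; actually sampling a tensor from an `Init` value is left to the variable/VarBuilder side of the crate.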