//! Layer Normalization.
//!
//! This layer applies Layer Normalization over a mini-batch of inputs as described in [`Layer
//! Normalization`]. The input is expected to have three dimensions: a batch dimension, a length,
//! and a hidden size; the normalization is applied over the last dimension. For each position,
//! the values along the hidden dimension are normalized to zero mean and unit variance, then
//! scaled by `weight` and shifted by `bias`.
//!
//! # Example
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
//! use candle_nn::LayerNorm;
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(1f32, &Cpu)?;
//! let b = Tensor::new(0f32, &Cpu)?;
//! let layer = LayerNorm::new(w, b, 1e-5);
//!
//! let xs = Tensor::new(
//!     &[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]],
//!     &Cpu)?;
//! let ys = layer.forward(&xs)?;
//! assert_eq!(
//!     ys.to_vec3::<f32>()?,
//!     &[[[-1.2247356, 0.0, 1.2247356],
//!        [-1.2247356, 0.0, 1.2247356],
//!        [ 1.2247356, 0.0, -1.2247356]]]);
//! # Ok(()) }
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
use candle::{DType, Result, Tensor};

// This layer norm version handles both a weight and a bias, and subtracts the mean.
#[derive(Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Tensor,
    eps: f64,
}

impl LayerNorm {
    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
        Self { weight, bias, eps }
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x_dtype = x.dtype();
        // Accumulate in f32 for half-precision inputs to avoid numerical issues.
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
        let x = x.to_dtype(internal_dtype)?;
        // Subtract the mean computed over the hidden dimension.
        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
        let x = x.broadcast_sub(&mean_x)?;
        // Divide by the standard deviation, adding `eps` to the variance for stability.
        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        // Scale and shift back in the original dtype.
        let x = x_normed
            .to_dtype(x_dtype)?
            .broadcast_mul(&self.weight)?
            .broadcast_add(&self.bias)?;
        Ok(x)
    }
}

pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
    let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
    let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
    Ok(LayerNorm::new(weight, bias, eps))
}
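
// The sketch below is an additional sanity test, not part of the original module: it assumes
// only the `LayerNorm` API defined above plus `candle::Device::Cpu`, and checks the defining
// property of layer norm, i.e. that with a unit weight and zero bias each row of the output has
// (approximately) zero mean and unit variance along the hidden dimension.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::Device;

    #[test]
    fn forward_normalizes_last_dimension() -> Result<()> {
        let device = Device::Cpu;
        let w = Tensor::new(1f32, &device)?;
        let b = Tensor::new(0f32, &device)?;
        let layer = LayerNorm::new(w, b, 1e-5);
        let xs = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], &device)?;
        let ys = layer.forward(&xs)?;
        // Flatten the batch dimension and check every row along the hidden dimension.
        for row in ys.to_vec3::<f32>()?.concat() {
            let mean = row.iter().sum::<f32>() / row.len() as f32;
            let var = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / row.len() as f32;
            assert!(mean.abs() < 1e-5, "row mean should be ~0, got {mean}");
            assert!((var - 1.0).abs() < 1e-3, "row variance should be ~1, got {var}");
        }
        Ok(())
    }
}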