//! Layer Normalization.
//!
//! This layer applies Layer Normalization over a mini-batch of inputs as described in [`Layer
//! Normalization`]. The input is expected to have three dimensions: a batch dimension, a length,
//! and a hidden size; the normalization is applied over the last dimension.
//!
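//! In terms of the weight `w`, the bias `b`, and the stabilizing constant `eps`, the output for
//! an input `x` is `(x - mean(x)) / sqrt(var(x) + eps) * w + b`, where the mean and the (biased)
//! variance are taken over the hidden dimension.
//!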
//! # Example
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
//! use candle_nn::LayerNorm;
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(1f32, &Cpu)?;
//! let b = Tensor::new(0f32, &Cpu)?;
//! let layer = LayerNorm::new(w, b, 1e-5);
//!
//! let xs = Tensor::new(
//!     &[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]],
//!     &Cpu)?;
//! let ys = layer.forward(&xs)?;
//! assert_eq!(
//!     ys.to_vec3::<f32>()?,
//!     &[[[-1.2247356, 0.0, 1.2247356],
//!        [-1.2247356, 0.0, 1.2247356],
//!        [ 1.2247356, 0.0, -1.2247356]]]);
//! # Ok(()) }
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
use candle::{DType, Result, Tensor};

// This layer norm version handles both a weight and a bias, and it subtracts the mean from the
// input rather than only rescaling it.
#[derive(Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Tensor,
    eps: f64,
}

impl LayerNorm {
    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
        Self { weight, bias, eps }
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x_dtype = x.dtype();
        // Low-precision float inputs are upcast to f32 so that the sums and squares below do
        // not lose precision or overflow.
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
        let x = x.to_dtype(internal_dtype)?;
        // Subtract the mean taken over the hidden dimension.
        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
        let x = x.broadcast_sub(&mean_x)?;
        // Divide by the (biased) standard deviation, with eps added for numerical stability.
        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        // Convert back to the original dtype before applying the affine transform.
        let x = x_normed
            .to_dtype(x_dtype)?
            .broadcast_mul(&self.weight)?
            .broadcast_add(&self.bias)?;
        Ok(x)
    }
}

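
/// Creates a `LayerNorm` layer with the given hidden size and epsilon, fetching its weight and
/// bias through the var-builder. When the variables are not already present, the weight is
/// initialized to ones and the bias to zeros.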
pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
    let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
    let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
    Ok(LayerNorm::new(weight, bias, eps))
}
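
// A small sketch of a sanity test for the mixed-precision path in `forward`: half-precision
// inputs should be normalized internally in f32 and come back in their original dtype. The
// assertions only cover shapes and dtypes so they do not depend on f16 rounding; the test
// assumes the CPU backend supports f16 tensors.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::Device;

    #[test]
    fn layer_norm_preserves_dtype() -> Result<()> {
        let dev = Device::Cpu;
        // Scalar weight and bias, cast to f16 to match the input dtype.
        let w = Tensor::new(1f32, &dev)?.to_dtype(DType::F16)?;
        let b = Tensor::new(0f32, &dev)?.to_dtype(DType::F16)?;
        let layer = LayerNorm::new(w, b, 1e-5);
        let xs = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.]]], &dev)?.to_dtype(DType::F16)?;
        let ys = layer.forward(&xs)?;
        // `forward` upcasts to f32 internally but converts back before the affine transform.
        assert_eq!(ys.dtype(), DType::F16);
        assert_eq!(ys.dims3()?, (1, 2, 3));
        Ok(())
    }
}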