Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 03:54:56 +00:00)

Layer norm tweaks (#482)
* Add some options to make layer-norm more configurable.
* Add the rms-norm variant.
* Replace the RmsNorm with the shared bits.
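For reference, a minimal usage sketch of the API added in the diff below. The `layer_norm`, `rms_norm`, and `LayerNormConfig` items come from this change; the `VarMap`/`VarBuilder` setup and the `candle_nn` re-exports are assumptions and may differ between candle versions.

// Usage sketch (assumption, not part of this commit): drive the new
// `layer_norm` / `rms_norm` builders added in the diff below.
// The VarMap/VarBuilder setup and the candle_nn re-exports are assumed.
use candle::{DType, Device, Result, Tensor};
use candle_nn::{layer_norm, rms_norm, LayerNormConfig, VarBuilder, VarMap};

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // A bare f64 epsilon still works thanks to `impl From<f64> for LayerNormConfig`,
    // so existing `layer_norm(size, eps, vb)` call sites keep compiling.
    let ln = layer_norm(16, 1e-5, vb.pp("ln"))?;

    // The explicit config exposes the new knobs (here: no learnable bias).
    let cfg = LayerNormConfig {
        eps: 1e-5,
        remove_mean: true,
        affine: false,
    };
    let ln_no_bias = layer_norm(16, cfg, vb.pp("ln_no_bias"))?;

    // The rms-norm variant: no mean removal, no bias, only a learnable scale.
    let rms = rms_norm(16, 1e-5, vb.pp("rms"))?;

    // forward() expects (batch, seq_len, hidden_size) inputs.
    let xs = Tensor::randn(0f32, 1f32, (2, 4, 16), &dev)?;
    let _ = ln.forward(&xs)?;
    let _ = ln_no_bias.forward(&xs)?;
    let _ = rms.forward(&xs)?;
    Ok(())
}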
@@ -30,17 +30,70 @@
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
 use candle::{DType, Result, Tensor};
 
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct LayerNormConfig {
+    pub eps: f64,
+    /// Whether to remove the mean or not, the default is true and when set to false, this turns
+    /// this layer into RmsNorm.
+    pub remove_mean: bool,
+    pub affine: bool,
+}
+
+impl Default for LayerNormConfig {
+    fn default() -> Self {
+        Self {
+            eps: 1e-5,
+            remove_mean: true,
+            affine: true,
+        }
+    }
+}
+
+impl From<f64> for LayerNormConfig {
+    fn from(eps: f64) -> Self {
+        Self {
+            eps,
+            remove_mean: true,
+            affine: true,
+        }
+    }
+}
+
 // This layer norm version handles both weight and bias so removes the mean.
 #[derive(Debug)]
 pub struct LayerNorm {
     weight: Tensor,
-    bias: Tensor,
+    bias: Option<Tensor>,
+    remove_mean: bool,
     eps: f64,
 }
 
 impl LayerNorm {
     pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
+        Self {
+            weight,
+            bias: Some(bias),
+            remove_mean: true,
+            eps,
+        }
     }
+
+    pub fn new_no_bias(weight: Tensor, eps: f64) -> Self {
+        Self {
+            weight,
+            bias: None,
+            remove_mean: true,
+            eps,
+        }
+    }
+
+    pub fn rms_norm(weight: Tensor, eps: f64) -> Self {
+        Self {
+            weight,
+            bias: None,
+            remove_mean: false,
+            eps,
+        }
+    }
 
     pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -51,20 +104,47 @@ impl LayerNorm {
         };
         let (_bsize, _seq_len, hidden_size) = x.dims3()?;
         let x = x.to_dtype(internal_dtype)?;
-        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
+        let x = if self.remove_mean {
+            let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+            x.broadcast_sub(&mean_x)?
+        } else {
+            x
+        };
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(x_dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
+        let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
     }
 }
 
-pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+pub fn layer_norm<C: Into<LayerNormConfig>>(
+    size: usize,
+    config: C,
+    vb: crate::VarBuilder,
+) -> Result<LayerNorm> {
+    let config = config.into();
     let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
-    let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
-    Ok(LayerNorm::new(weight, bias, eps))
+    let bias = if config.affine {
+        Some(vb.get_or_init(size, "bias", crate::Init::Const(0.))?)
+    } else {
+        None
+    };
+    Ok(LayerNorm {
+        weight,
+        bias,
+        remove_mean: config.remove_mean,
+        eps: config.eps,
+    })
 }
+
+pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+    let config = LayerNormConfig {
+        eps,
+        remove_mean: false,
+        affine: false,
+    };
+    layer_norm(size, config, vb)
+}
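A second sketch (also an assumption, not from the commit) for the case where the weight tensor already exists, e.g. when swapping a hand-rolled RmsNorm for the shared implementation via the new `LayerNorm` constructors shown in the diff above:

// Direct constructors on existing tensors (sketch, assumptions as above).
use candle::{DType, Device, Result, Tensor};
use candle_nn::LayerNorm;

fn demo_direct() -> Result<()> {
    let dev = Device::Cpu;
    let weight = Tensor::ones(16, DType::F32, &dev)?;
    let bias = Tensor::zeros(16, DType::F32, &dev)?;

    // Classic layer norm: remove the mean, scale by `weight`, add `bias`.
    let ln = LayerNorm::new(weight.clone(), bias, 1e-5);
    // Layer norm without the learnable bias term.
    let ln_no_bias = LayerNorm::new_no_bias(weight.clone(), 1e-5);
    // What the standalone RmsNorm used to do: no mean removal, no bias.
    let rms = LayerNorm::rms_norm(weight, 1e-5);

    let xs = Tensor::randn(0f32, 1f32, (2, 4, 16), &dev)?;
    let _ = ln.forward(&xs)?;
    let _ = ln_no_bias.forward(&xs)?;
    let _ = rms.forward(&xs)?;
    Ok(())
}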