Add an abstract type for RmsNorm. (#499)

Laurent Mazare
2023-08-18 08:52:14 +01:00
committed by GitHub
parent a22b1bed7b
commit 13401df4d1
8 changed files with 45 additions and 24 deletions


@@ -140,11 +140,29 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
     })
 }
 
-pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+/// RmsNorm is a specialized version of the LayerNorm module.
+#[derive(Debug)]
+pub struct RmsNorm(LayerNorm);
+
+impl RmsNorm {
+    pub fn new(weight: Tensor, eps: f64) -> Self {
+        Self(LayerNorm::rms_norm(weight, eps))
+    }
+
+    pub fn into_inner(self) -> LayerNorm {
+        self.0
+    }
+
+    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.0.forward(xs)
+    }
+}
+
+pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<RmsNorm> {
     let config = LayerNormConfig {
         eps,
         remove_mean: false,
         affine: false,
     };
-    layer_norm(size, config, vb)
+    Ok(RmsNorm(layer_norm(size, config, vb)?))
 }
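
The wrapper keeps the RMS computation inside LayerNorm but gives callers a dedicated type. A minimal usage sketch (not part of this commit, assuming the crates are imported as candle_core and candle_nn, and that LayerNorm::rms_norm already exists as used above):

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::RmsNorm;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Per-channel scale; RmsNorm::new wraps it via LayerNorm::rms_norm.
        let weight = Tensor::ones(4, DType::F32, &dev)?;
        let rms = RmsNorm::new(weight, 1e-5);
        let xs = Tensor::randn(0f32, 1f32, (2, 4), &dev)?;
        // forward delegates to the inner LayerNorm.
        let ys = rms.forward(&xs)?;
        println!("{ys}");
        Ok(())
    }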


@@ -17,7 +17,7 @@ pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
 pub use embedding::{embedding, Embedding};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
-pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig};
+pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::{AdamW, ParamsAdamW, SGD};
 pub use var_builder::{VarBuilder, VarMap};
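
With the re-export in place, downstream code can name the type and use the builder path directly. A hedged sketch (VarMap and VarBuilder::from_varmap are existing candle-nn APIs, not part of this diff; crate names assumed to be candle_core and candle_nn):

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{rms_norm, RmsNorm, VarBuilder, VarMap};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        // rms_norm now returns the RmsNorm wrapper rather than a bare LayerNorm.
        let rms: RmsNorm = rms_norm(8, 1e-6, vb)?;
        let xs = Tensor::randn(0f32, 1f32, (3, 8), &dev)?;
        let ys = rms.forward(&xs)?;
        println!("{:?}", ys.shape());
        Ok(())
    }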