mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add an abstract type for RmsNorm. (#499)
This commit is contained in:
@ -140,11 +140,29 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
|
||||
/// RmsNorm is a specialized version of the LayerNorm module.
///
/// This is a newtype: its only field is a [`LayerNorm`] value, and the type
/// exists so that RMS normalization layers are distinguishable from plain
/// layer-norm layers at the type level.
|
||||
#[derive(Debug)]
|
||||
pub struct RmsNorm(LayerNorm);
|
||||
|
||||
impl RmsNorm {
|
||||
/// Creates an `RmsNorm` from an existing `weight` tensor and an `eps` value.
///
/// The wrapped layer is produced by `LayerNorm::rms_norm(weight, eps)`.
pub fn new(weight: Tensor, eps: f64) -> Self {
|
||||
Self(LayerNorm::rms_norm(weight, eps))
|
||||
}
|
||||
|
||||
/// Consumes the wrapper and returns the underlying [`LayerNorm`].
pub fn into_inner(self) -> LayerNorm {
|
||||
// Field 0 is the only field of the newtype.
self.0
|
||||
}
|
||||
|
||||
/// Applies the normalization to `xs` by delegating to the wrapped
/// [`LayerNorm`]'s `forward`; its result (including any error) is returned
/// unchanged.
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.0.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds an [`RmsNorm`] layer of the given `size` using the variables
/// provided by `vb`.
///
/// The layer is constructed through the generic `layer_norm` constructor with
/// a [`LayerNormConfig`] that carries the caller's `eps` and sets both
/// `remove_mean` and `affine` to `false`.
///
/// # Errors
///
/// Propagates any error returned by `layer_norm`.
pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<RmsNorm> {
|
||||
let config = LayerNormConfig {
|
||||
eps,
|
||||
remove_mean: false,
|
||||
|
||||
affine: false,
|
||||
};
|
||||
// NOTE(review): the next two expressions look like the removed/added pair of
// one diff hunk; only the `Ok(RmsNorm(..))` form matches the declared
// `Result<RmsNorm>` return type — confirm against the actual source file.
layer_norm(size, config, vb)
|
||||
Ok(RmsNorm(layer_norm(size, config, vb)?))
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
|
||||
pub use embedding::{embedding, Embedding};
|
||||
pub use group_norm::{group_norm, GroupNorm};
|
||||
pub use init::Init;
|
||||
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig};
|
||||
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
|
||||
pub use linear::{linear, linear_no_bias, Linear};
|
||||
pub use optim::{AdamW, ParamsAdamW, SGD};
|
||||
pub use var_builder::{VarBuilder, VarMap};
|
||||
|
Reference in New Issue
Block a user