Share the layer-norm implementation. (#1248)

commit 6975c65112 (parent a2a20aeecc)
Author: Laurent Mazare
Date: 2023-11-03 06:30:05 +01:00
Committed by: GitHub
2 changed files with 32 additions and 56 deletions

@@ -124,3 +124,34 @@ impl std::fmt::Debug for QMatMul {
        write!(f, "QMatMul")
    }
}

#[derive(Clone, Debug)]
pub struct LayerNorm {
    inner: candle_nn::LayerNorm,
    span: tracing::Span,
}

impl LayerNorm {
    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
        let inner = candle_nn::LayerNorm::new(weight, bias, eps);
        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
        Self { inner, span }
    }
}

impl Module for LayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}

pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
    size: usize,
    c: C,
    vb: VarBuilder,
) -> Result<LayerNorm> {
    let inner = candle_nn::layer_norm(size, c, vb)?;
    let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
    Ok(LayerNorm { inner, span })
}
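
For context, a minimal usage sketch of the shared wrapper; this is not part of the commit. It assumes the code above lands in a publicly importable module (the `with_tracing` path in the comment is a guess) and that `candle-core`, `candle-nn`, and `tracing` are available as dependencies. It exercises both constructors the diff defines: `LayerNorm::new` for explicit weight/bias tensors and `layer_norm` for `VarBuilder`-based loading.

// Sketch, not part of the diff. The import path below is an assumption;
// point it at wherever the shared wrapper is defined in the tree.
// use candle_transformers::models::with_tracing::{layer_norm, LayerNorm};
use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Path 1: explicit weight/bias tensors, via LayerNorm::new.
    let weight = Tensor::ones(8, DType::F32, &device)?;
    let bias = Tensor::zeros(8, DType::F32, &device)?;
    let ln = LayerNorm::new(weight, bias, 1e-5);

    // Path 2: through a VarBuilder, as models typically do. A fresh VarMap
    // default-initializes the weight/bias variables; this is demo-only.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let ln2 = layer_norm(8, 1e-5, vb.pp("ln"))?;

    // Each forward call enters the "layer-norm" TRACE span defined in the
    // wrapper, so every instance aggregates under one label when profiling.
    let xs = Tensor::randn(0f32, 1f32, (2, 8), &device)?;
    let _y1 = ln.forward(&xs)?;
    let _y2 = ln2.forward(&xs)?;
    Ok(())
}

Since all models construct the wrapper through these two entry points, they share the single `layer-norm` span name, which is what makes the implementation worth centralizing rather than re-declaring per model.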