Use a common with_tracing::RmsNorm in a few models. (#1871)

* Add RmsNorm with tracing.

* Use with_tracing::RmsNorm in some models.
This commit is contained in:
Jani Monoses
2024-03-18 22:40:06 +02:00
committed by GitHub
parent 6a966cf9e0
commit 90fc82211f
6 changed files with 29 additions and 111 deletions

View File

@ -167,3 +167,24 @@ pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
Ok(LayerNorm { inner, span })
}
#[derive(Debug, Clone)]
pub struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
}
impl RmsNorm {
pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let inner = candle_nn::rms_norm(size, eps, vb)?;
Ok(Self { inner, span })
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}