mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Use a common with_tracing::RmsNorm in a few models. (#1871)
* Add RmsNorm with tracing. * Use with_tracing::RmsNorm in some models.
This commit is contained in:
@ -167,3 +167,24 @@ pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
|
||||
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
|
||||
Ok(LayerNorm { inner, span })
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RmsNorm {
|
||||
inner: candle_nn::RmsNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let inner = candle_nn::rms_norm(size, eps, vb)?;
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for RmsNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user