mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Share the layer-norm implementation. (#1248)
This commit is contained in:
@ -124,3 +124,34 @@ impl std::fmt::Debug for QMatMul {
|
||||
write!(f, "QMatMul")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LayerNorm {
|
||||
inner: candle_nn::LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
|
||||
let inner = candle_nn::LayerNorm::new(weight, bias, eps);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
|
||||
Self { inner, span }
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNorm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
|
||||
size: usize,
|
||||
c: C,
|
||||
vb: VarBuilder,
|
||||
) -> Result<LayerNorm> {
|
||||
let inner = candle_nn::layer_norm(size, c, vb)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
|
||||
Ok(LayerNorm { inner, span })
|
||||
}
|
||||
|
Reference in New Issue
Block a user