diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index aca520da..d6826a16 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -1,4 +1,4 @@
-use super::with_tracing::{linear, Linear};
+use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
 use candle::{DType, Device, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
@@ -33,47 +33,6 @@ impl HiddenActLayer {
     }
 }
 
-#[derive(Debug)]
-pub struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-    span: tracing::Span,
-}
-
-impl LayerNorm {
-    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
-        Self {
-            weight,
-            bias,
-            eps,
-            span,
-        }
-    }
-}
-
-impl Module for LayerNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let x_dtype = x.dtype();
-        let internal_dtype = match x_dtype {
-            DType::F16 | DType::BF16 => DType::F32,
-            d => d,
-        };
-        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
-        let x = x.to_dtype(internal_dtype)?;
-        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(x_dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
-}
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
@@ -174,20 +133,6 @@ impl Module for Dropout {
     }
 }
 
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
-        (Ok(weight), Ok(bias)) => (weight, bias),
-        (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
-                (weight, bias)
-            } else {
-                return Err(err);
-            }
-        }
-    };
-    Ok(LayerNorm::new(weight, bias, eps))
-}
-
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
 struct BertEmbeddings {
     word_embeddings: Embedding,
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index a657011c..53e21551 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -124,3 +124,34 @@ impl std::fmt::Debug for QMatMul {
         write!(f, "QMatMul")
     }
 }
+
+#[derive(Clone, Debug)]
+pub struct LayerNorm {
+    inner: candle_nn::LayerNorm,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        let inner = candle_nn::LayerNorm::new(weight, bias, eps);
+        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+        Self { inner, span }
+    }
+}
+
+impl Module for LayerNorm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
+pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
+    size: usize,
+    c: C,
+    vb: VarBuilder,
+) -> Result<LayerNorm> {
+    let inner = candle_nn::layer_norm(size, c, vb)?;
+    let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+    Ok(LayerNorm { inner, span })
+}
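
For reference, a minimal sketch (not part of the patch) of how a call site inside candle-transformers could use the shared with_tracing::layer_norm helper after this change. The "LayerNorm" variable prefix, the eps argument, and the helper function names are assumptions for illustration, not code from the diff.

// Sketch only: loading and applying a traced LayerNorm from a VarBuilder,
// assuming this code lives alongside bert.rs so `super::with_tracing` resolves.
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};

use super::with_tracing::{layer_norm, LayerNorm};

// Hypothetical helper: the "LayerNorm" prefix and f64 eps mirror typical BERT
// checkpoints but are not taken from the diff.
fn load_traced_layer_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
    // An `eps: f64` converts into `candle_nn::LayerNormConfig` via the generic
    // `C: Into<candle_nn::LayerNormConfig>` bound, so callers keep the same
    // `(size, eps, vb)` shape as the old bert.rs-local helper.
    layer_norm(hidden_size, eps, vb.pp("LayerNorm"))
}

fn apply(ln: &LayerNorm, hidden_states: &Tensor) -> Result<Tensor> {
    // The forward pass enters the "layer-norm" tracing span, then delegates
    // to the wrapped candle_nn::LayerNorm.
    ln.forward(hidden_states)
}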