Share the layer-norm implementation. (#1248)
@@ -1,4 +1,4 @@
-use super::with_tracing::{linear, Linear};
+use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
 use candle::{DType, Device, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
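For context, a hedged sketch of what the widened import is for: model blocks that previously used the file-local LayerNorm now hold the traced LayerNorm from with_tracing and construct it through the shared layer_norm helper. ExampleBlock, its fields, and the "dense"/"LayerNorm" prefixes below are illustrative stand-ins, not code from this commit:

    use candle::Result;
    use candle_nn::VarBuilder;
    use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};

    // Illustrative only: a block holding the shared, traced layer norm.
    struct ExampleBlock {
        dense: Linear,
        norm: LayerNorm,
    }

    impl ExampleBlock {
        fn load(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
            let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
            let norm = layer_norm(hidden_size, eps, vb.pp("LayerNorm"))?;
            Ok(Self { dense, norm })
        }
    }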
@@ -33,47 +33,6 @@ impl HiddenActLayer {
     }
 }
 
-#[derive(Debug)]
-pub struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-    span: tracing::Span,
-}
-
-impl LayerNorm {
-    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
-        Self {
-            weight,
-            bias,
-            eps,
-            span,
-        }
-    }
-}
-
-impl Module for LayerNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let x_dtype = x.dtype();
-        let internal_dtype = match x_dtype {
-            DType::F16 | DType::BF16 => DType::F32,
-            d => d,
-        };
-        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
-        let x = x.to_dtype(internal_dtype)?;
-        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(x_dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
-}
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
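The block removed above is the hand-rolled layer norm: upcast f16/bf16 activations to f32, subtract the per-row mean, scale by the reciprocal of sqrt(mean square + eps), then apply weight and bias. Below is a minimal sketch of comparing candle_nn::LayerNorm (which the shared implementation wraps, see the last hunk) against that same arithmetic on f32 input; the shapes, eps, and the summed-difference check are arbitrary choices for illustration:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::Module;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let hidden = 4usize;
        let eps = 1e-12;
        let x = Tensor::randn(0f32, 1f32, (2, 3, hidden), &dev)?;
        let weight = Tensor::ones(hidden, DType::F32, &dev)?;
        let bias = Tensor::zeros(hidden, DType::F32, &dev)?;

        // Shared path: candle_nn's layer norm with the same weight/bias/eps.
        let ln = candle_nn::LayerNorm::new(weight.clone(), bias.clone(), eps);
        let y_shared = ln.forward(&x)?;

        // Manual path: the same arithmetic as the removed bert.rs code (f32 input, no upcast needed).
        let mean = (x.sum_keepdim(2)? / hidden as f64)?;
        let xc = x.broadcast_sub(&mean)?;
        let var = (xc.sqr()?.sum_keepdim(2)? / hidden as f64)?;
        let y_manual = xc
            .broadcast_div(&(var + eps)?.sqrt()?)?
            .broadcast_mul(&weight)?
            .broadcast_add(&bias)?;

        // Total absolute difference between the two paths.
        let diff = (&y_shared - &y_manual)?.abs()?.sum_all()?;
        println!("total abs diff: {diff}");
        Ok(())
    }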
@@ -174,20 +133,6 @@ impl Module for Dropout {
     }
 }
 
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
-        (Ok(weight), Ok(bias)) => (weight, bias),
-        (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
-                (weight, bias)
-            } else {
-                return Err(err);
-            }
-        }
-    };
-    Ok(LayerNorm::new(weight, bias, eps))
-}
-
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
 struct BertEmbeddings {
     word_embeddings: Embedding,
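One behavioral detail in the removed helper: it fell back to the gamma/beta tensor names used by some older BERT checkpoints before reporting an error. If a caller still needs that fallback on top of the shared implementation, it could be layered over with_tracing's LayerNorm::new; a hypothetical sketch (layer_norm_compat is a made-up name, the vb.get calls mirror the removed code):

    use candle::Result;
    use candle_nn::VarBuilder;
    use super::with_tracing::LayerNorm;

    // Hypothetical compatibility helper, not part of this commit.
    fn layer_norm_compat(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
        let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
            (Ok(weight), Ok(bias)) => (weight, bias),
            // Older checkpoints store the affine parameters as gamma/beta.
            (Err(err), _) | (_, Err(err)) => match (vb.get(size, "gamma"), vb.get(size, "beta")) {
                (Ok(weight), Ok(bias)) => (weight, bias),
                _ => return Err(err),
            },
        };
        Ok(LayerNorm::new(weight, bias, eps))
    }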
@@ -124,3 +124,34 @@ impl std::fmt::Debug for QMatMul {
         write!(f, "QMatMul")
     }
 }
+
+#[derive(Clone, Debug)]
+pub struct LayerNorm {
+    inner: candle_nn::LayerNorm,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        let inner = candle_nn::LayerNorm::new(weight, bias, eps);
+        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+        Self { inner, span }
+    }
+}
+
+impl Module for LayerNorm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
+pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
+    size: usize,
+    c: C,
+    vb: VarBuilder,
+) -> Result<LayerNorm> {
+    let inner = candle_nn::layer_norm(size, c, vb)?;
+    let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+    Ok(LayerNorm { inner, span })
+}
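Unlike the old fn layer_norm(size, eps, vb), the shared constructor added above accepts any C: Into<candle_nn::LayerNormConfig>, so call sites can keep passing a bare f64 eps or supply a full config. A hedged usage sketch; build_norms, the prefixes, and the assumption that LayerNormConfig exposes a public eps field with a Default impl are illustrative:

    use candle::Result;
    use candle_nn::VarBuilder;
    use super::with_tracing::{layer_norm, LayerNorm};

    // Illustrative only: build two traced layer norms with equivalent settings.
    fn build_norms(hidden_size: usize, vb: VarBuilder) -> Result<(LayerNorm, LayerNorm)> {
        // A bare f64 works because candle_nn::LayerNormConfig can be built from an eps value.
        let ln_a = layer_norm(hidden_size, 1e-12, vb.pp("a"))?;
        // An explicit config spells out the same thing (assuming the struct keeps its Default impl).
        let cfg = candle_nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
        let ln_b = layer_norm(hidden_size, cfg, vb.pp("b"))?;
        Ok((ln_a, ln_b))
    }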