mirror of https://github.com/huggingface/candle.git
Layer norm tweaks (#482)
* Add some options to make layer-norm more configurable.
* Add the rms-norm variant.
* Replace the RmsNorm with the shared bits.
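Not part of the commit, just context for the first bullet: a minimal sketch of what the more configurable layer norm looks like from the caller's side. It assumes candle_nn's LayerNormConfig (with eps, remove_mean, and affine fields) together with the VarMap/VarBuilder setup used in the repo's training examples; the "ln"/"rms" variable prefixes are illustrative only.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{layer_norm, rms_norm, LayerNormConfig, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A VarMap-backed VarBuilder creates suitably initialized weights on demand.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // Classic layer norm: remove the mean, then apply a learned scale and bias.
    let cfg = LayerNormConfig { eps: 1e-5, remove_mean: true, affine: true };
    let ln = layer_norm(16, cfg, vb.pp("ln"))?;

    // RMS norm is the variant without mean removal (and without a bias); the
    // dedicated helper builds that configuration, and it is what the diff
    // below delegates RmsNorm::forward to.
    let rms = rms_norm(16, 1e-5, vb.pp("rms"))?;

    let x = Tensor::randn(0f32, 1f32, (2, 4, 16), &dev)?;
    let _ = ln.forward(&x)?;
    let _ = rms.forward(&x)?;
    Ok(())
}

With mean removal switched off, layer norm reduces to RMS norm, which is why the diff below can drop the hand-written math and reuse the shared implementation.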
@@ -152,35 +152,20 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 }

 struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
+    inner: candle_nn::LayerNorm,
     span: tracing::Span,
 }

 impl RmsNorm {
     fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps, span })
+        let inner = candle_nn::rms_norm(size, eps, vb)?;
+        Ok(Self { inner, span })
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let in_dtype = x.dtype();
-        // This is a no-op if x's dtype is already f32.
-        let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        let x = x.to_dtype(in_dtype)?;
-        Ok(x)
+        self.inner.forward(x)
     }
 }
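For reference, the arithmetic the deleted hunk performed, restated as a self-contained helper. This is only a sketch: rms_norm_manual is a hypothetical name, the body mirrors the removed lines, and it assumes the repo's `candle` alias for candle-core. The computation is y = x / sqrt(mean(x², last dim) + eps) * scale, which is exactly what the shared candle_nn::rms_norm now owns.

use candle::{DType, Result, Tensor};

// Hypothetical helper mirroring the deleted hunk above.
fn rms_norm_manual(x: &Tensor, scale: &Tensor, eps: f64) -> Result<Tensor> {
    let in_dtype = x.dtype();
    // Compute in f32 for numerical stability; a no-op if x is already f32.
    let x = x.to_dtype(DType::F32)?;
    let (b_sz, seq_len, hidden_size) = x.dims3()?;
    // Mean of the squares over the hidden dimension, broadcast back to x's shape.
    let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?
        .broadcast_as((b_sz, seq_len, hidden_size))?;
    let x_normed = (x / (norm_x + eps)?.sqrt()?)?;
    // Learned per-channel gain, broadcast over the batch and sequence dims.
    let size = scale.dims1()?;
    let scale = scale
        .to_dtype(DType::F32)?
        .broadcast_as((b_sz, seq_len, size))?;
    (scale * x_normed)?.to_dtype(in_dtype)
}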