Layer norm tweaks (#482)

* Add some options to make layer-norm more configurable.

* Add the rms-norm variant.

* Replace the hand-rolled RmsNorm implementations with the shared candle-nn version.
Laurent Mazare authored on 2023-08-17 10:07:13 +01:00, committed by GitHub
parent d99cac3ec3, commit d32e8199cd
7 changed files with 124 additions and 158 deletions
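
For context, a minimal sketch of the configurable API the commit message describes, using the LayerNorm::new and LayerNorm::rms_norm constructors that the patch below calls. The crate paths are assumptions based on current candle conventions, and depending on the candle version you may also need "use candle_nn::Module;" to bring forward into scope; the snippet is illustrative, not part of the diff.

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::LayerNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let weight = Tensor::ones(4, DType::F32, &dev)?;
    let bias = Tensor::zeros(4, DType::F32, &dev)?;

    // Classic layer norm: center on the mean, normalize, then scale and shift.
    let ln = LayerNorm::new(weight.clone(), bias, 1e-5);
    // The new rms-norm variant: no mean subtraction and no bias, scale only.
    let rms = LayerNorm::rms_norm(weight, 1e-5);

    let xs = Tensor::new(&[[1f32, 2., 3., 4.]], &dev)?;
    println!("layer-norm: {}", ln.forward(&xs)?);
    println!("rms-norm:   {}", rms.forward(&xs)?);
    Ok(())
}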

@@ -14,8 +14,7 @@ const MAX_SEQ_LEN: usize = 4096;
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
 struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
+    inner: candle_nn::LayerNorm,
     span: tracing::Span,
 }
 
@@ -23,26 +22,13 @@ impl RmsNorm {
     fn new(scale: QTensor) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = scale.dequantize(&Device::Cpu)?;
-        Ok(Self {
-            scale,
-            eps: 1e-5,
-            span,
-        })
+        let inner = candle_nn::LayerNorm::rms_norm(scale, 1e-5);
+        Ok(Self { inner, span })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
+        self.inner.forward(x)
     }
 }
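
For reference, the deleted body computes the textbook RMS norm, y = scale * x / sqrt(mean(x^2) + eps); unlike classic layer norm it never centers the input and carries no bias term. The same computation as a free function, assembled from the removed lines (the name rms_norm_manual and the imports are ours):

use candle_core::{DType, Result, Tensor, D};

// RMS norm over the last (hidden) dimension, as the removed forward() did:
// divide by the root mean square of x, then apply the per-channel scale.
fn rms_norm_manual(x: &Tensor, scale: &Tensor, eps: f64) -> Result<Tensor> {
    let (b_sz, seq_len, hidden_size) = x.dims3()?;
    // Mean of x^2 along the hidden dimension, kept so it can be broadcast back.
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
    // Normalize: x / sqrt(mean(x^2) + eps).
    let x_normed = (x / (norm_x + eps)?.sqrt()?)?;
    // Broadcast the learned per-channel scale over the batch and sequence dims.
    let size = scale.dims1()?;
    let scale = scale
        .to_dtype(DType::F32)?
        .broadcast_as((b_sz, seq_len, size))?;
    scale * x_normed
}

This is exactly what candle_nn::LayerNorm::rms_norm(scale, 1e-5) now encapsulates, so each example no longer needs to carry its own copy.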