mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Layer norm tweaks (#482)
* Add some options to make layer-norm more configurable. * Add the rms-norm variant. * Replace the RmsNorm with the shared bits.
This commit is contained in:
@ -14,8 +14,7 @@ const MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
|
||||
struct RmsNorm {
|
||||
scale: Tensor,
|
||||
eps: f64,
|
||||
inner: candle_nn::LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
@ -23,26 +22,13 @@ impl RmsNorm {
|
||||
fn new(scale: QTensor) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let scale = scale.dequantize(&Device::Cpu)?;
|
||||
Ok(Self {
|
||||
scale,
|
||||
eps: 1e-5,
|
||||
span,
|
||||
})
|
||||
let inner = candle_nn::LayerNorm::rms_norm(scale, 1e-5);
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||
let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
|
||||
let size = self.scale.dims1()?;
|
||||
let scale = self
|
||||
.scale
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_as((b_sz, seq_len, size))?;
|
||||
let x = (scale * x_normed)?;
|
||||
Ok(x)
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user