Layer norm tweaks (#482)

* Add some options to make layer-norm more configurable.

* Add the rms-norm variant.

* Replace the RmsNorm with the shared bits.
This commit is contained in:
Laurent Mazare
2023-08-17 10:07:13 +01:00
committed by GitHub
parent d99cac3ec3
commit d32e8199cd
7 changed files with 124 additions and 158 deletions

View File

@ -1,6 +1,6 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, Embedding, Linear, VarBuilder};
use candle_nn::{embedding, rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -94,32 +94,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
struct RmsNorm {
scale: Tensor,
eps: f64,
}
impl RmsNorm {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
Ok(Self { scale, eps })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
let size = self.scale.dims1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((b_sz, seq_len, size))?;
let x = (scale * x_normed)?;
Ok(x)
}
}
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@ -262,14 +236,14 @@ impl Mlp {
}
struct Block {
rms_1: RmsNorm,
rms_1: LayerNorm,
attn: CausalSelfAttention,
rms_2: RmsNorm,
rms_2: LayerNorm,
mlp: Mlp,
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@ -290,9 +264,9 @@ impl Block {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
@ -305,7 +279,7 @@ impl Block {
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
ln_f: LayerNorm,
lm_head: Linear,
pub config: Config,
}
@ -325,7 +299,7 @@ impl Llama {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
.collect();