From d32e8199cd6c8381aa309528675d6d6a88c0f850 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 17 Aug 2023 10:07:13 +0100
Subject: [PATCH] Layer norm tweaks (#482)

* Add some options to make layer-norm more configurable.

* Add the rms-norm variant.

* Replace the RmsNorm with the shared bits.
---
 candle-examples/examples/llama/model.rs    |  23 +---
 candle-examples/examples/llama2-c/model.rs |  42 ++-----
 .../examples/llama_multiprocess/model.rs   |  45 +-------
 candle-examples/examples/quantized/main.rs |  22 +---
 candle-nn/src/layer_norm.rs                | 104 ++++++++++++++++--
 candle-nn/src/lib.rs                       |   2 +-
 candle-wasm-examples/llama2-c/src/model.rs |  44 ++------
 7 files changed, 124 insertions(+), 158 deletions(-)

diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 751b5902..e0bb70e7 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -152,35 +152,20 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 }
 
 struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
+    inner: candle_nn::LayerNorm,
     span: tracing::Span,
 }
 
 impl RmsNorm {
     fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps, span })
+        let inner = candle_nn::rms_norm(size, eps, vb)?;
+        Ok(Self { inner, span })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let in_dtype = x.dtype();
-        // This is a no-op if x's dtype is already f32.
-        let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        let x = x.to_dtype(in_dtype)?;
-        Ok(x)
+        self.inner.forward(x)
     }
 }
 
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 77900d27..75269665 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -1,6 +1,6 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, Embedding, Linear, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -94,32 +94,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
 }
 
-struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
-        Ok(Self { scale, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -262,14 +236,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -290,9 +264,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,
@@ -305,7 +279,7 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: RmsNorm,
+    ln_f: LayerNorm,
     lm_head: Linear,
     pub config: Config,
 }
@@ -325,7 +299,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
             .collect();
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 348248f6..ab4e382c 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::rc::Rc;
@@ -182,39 +182,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
     Ok(Embedding::new(embeddings, cfg.hidden_size))
 }
 
-struct RmsNorm {
-    scale: Tensor,
-}
-
-impl RmsNorm {
-    fn load(size: usize, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get(size, "weight")?;
-        Ok(Self::new(scale))
-    }
-
-    fn new(scale: Tensor) -> Self {
-        Self { scale }
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let in_dtype = x.dtype();
-        // This is a no-op if x's dtype is already f32.
-        let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let size = self.scale.shape().dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        let x = x.to_dtype(in_dtype)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     qkv_proj: TensorParallelColumnLinear,
     o_proj: TensorParallelRowLinear,
@@ -369,14 +336,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -397,9 +364,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
-        let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index f42d6f0f..94efb03f 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -14,8 +14,7 @@ const MAX_SEQ_LEN: usize = 4096;
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
 struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
+    inner: candle_nn::LayerNorm,
     span: tracing::Span,
 }
 
@@ -23,26 +22,13 @@ impl RmsNorm {
     fn new(scale: QTensor) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = scale.dequantize(&Device::Cpu)?;
-        Ok(Self {
-            scale,
-            eps: 1e-5,
-            span,
-        })
+        let inner = candle_nn::LayerNorm::rms_norm(scale, 1e-5);
+        Ok(Self { inner, span })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
+        self.inner.forward(x)
     }
 }
 
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 668f9a4b..f9892a2c 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -30,17 +30,70 @@
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
 use candle::{DType, Result, Tensor};
 
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct LayerNormConfig {
+    pub eps: f64,
+    /// Whether to remove the mean or not, the default is true and when set to false, this turns
+    /// this layer into RmsNorm.
+    pub remove_mean: bool,
+    pub affine: bool,
+}
+
+impl Default for LayerNormConfig {
+    fn default() -> Self {
+        Self {
+            eps: 1e-5,
+            remove_mean: true,
+            affine: true,
+        }
+    }
+}
+
+impl From<f64> for LayerNormConfig {
+    fn from(eps: f64) -> Self {
+        Self {
+            eps,
+            remove_mean: true,
+            affine: true,
+        }
+    }
+}
+
 // This layer norm version handles both weight and bias so removes the mean.
 #[derive(Debug)]
 pub struct LayerNorm {
     weight: Tensor,
-    bias: Tensor,
+    bias: Option<Tensor>,
+    remove_mean: bool,
     eps: f64,
 }
 
 impl LayerNorm {
     pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
+        Self {
+            weight,
+            bias: Some(bias),
+            remove_mean: true,
+            eps,
+        }
+    }
+
+    pub fn new_no_bias(weight: Tensor, eps: f64) -> Self {
+        Self {
+            weight,
+            bias: None,
+            remove_mean: true,
+            eps,
+        }
+    }
+
+    pub fn rms_norm(weight: Tensor, eps: f64) -> Self {
+        Self {
+            weight,
+            bias: None,
+            remove_mean: false,
+            eps,
+        }
     }
 
     pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -51,20 +104,47 @@ impl LayerNorm {
         };
         let (_bsize, _seq_len, hidden_size) = x.dims3()?;
         let x = x.to_dtype(internal_dtype)?;
-        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
+        let x = if self.remove_mean {
+            let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+            x.broadcast_sub(&mean_x)?
+        } else {
+            x
+        };
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(x_dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
+        let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
     }
 }
 
-pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+pub fn layer_norm<C: Into<LayerNormConfig>>(
+    size: usize,
+    config: C,
+    vb: crate::VarBuilder,
+) -> Result<LayerNorm> {
+    let config = config.into();
     let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
-    let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
-    Ok(LayerNorm::new(weight, bias, eps))
+    let bias = if config.affine {
+        Some(vb.get_or_init(size, "bias", crate::Init::Const(0.))?)
+    } else {
+        None
+    };
+    Ok(LayerNorm {
+        weight,
+        bias,
+        remove_mean: config.remove_mean,
+        eps: config.eps,
+    })
+}
+
+pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+    let config = LayerNormConfig {
+        eps,
+        remove_mean: false,
+        affine: false,
+    };
+    layer_norm(size, config, vb)
 }
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index ae955f56..05464ceb 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -17,7 +17,7 @@ pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
 pub use embedding::{embedding, Embedding};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
-pub use layer_norm::{layer_norm, LayerNorm};
+pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::{AdamW, ParamsAdamW, SGD};
 pub use var_builder::{VarBuilder, VarMap};
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 3231cabf..d2b787ae 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -71,32 +71,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
     Ok(Embedding::new(embeddings, cfg.dim))
 }
 
-struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -239,14 +213,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -267,9 +241,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-       let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+       let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm =
-           RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+           rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
        Ok(Self::new(
            input_layernorm,
            attn,
@@ -282,12 +256,12 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: RmsNorm,
+    ln_f: LayerNorm,
     lm_head: Linear,
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
         Self {
             wte,
             blocks,
@@ -311,7 +285,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
             .collect();
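
For readers who want to try the refactored layer out directly, here is a minimal sketch of the three LayerNorm constructors introduced in candle-nn/src/layer_norm.rs, driven from plain weight tensors with no VarBuilder involved. The tensor shapes and values are illustrative assumptions; the constructors and the inherent forward method come from the patch above.

use candle::{DType, Device, Result, Tensor};
use candle_nn::LayerNorm;

fn layer_norm_variants() -> Result<()> {
    let dev = Device::Cpu;
    let hidden_size = 4;
    let weight = Tensor::ones(hidden_size, DType::F32, &dev)?;
    let bias = Tensor::zeros(hidden_size, DType::F32, &dev)?;

    // Classic layer norm: remove the mean, scale by `weight`, shift by `bias`.
    let ln = LayerNorm::new(weight.clone(), bias, 1e-5);
    // Layer norm without the learned shift (bias is None internally).
    let ln_no_bias = LayerNorm::new_no_bias(weight.clone(), 1e-5);
    // RMS norm: no mean removal and no bias, i.e. xs * weight / sqrt(mean(xs^2) + eps),
    // the formula the per-example RmsNorm structs used to spell out by hand.
    let rms = LayerNorm::rms_norm(weight, 1e-5);

    // forward expects a (batch, seq_len, hidden_size) tensor, as in the examples.
    let xs = Tensor::new(&[[[1f32, 2., 3., 4.], [5., 6., 7., 8.]]], &dev)?;
    println!("{}", ln.forward(&xs)?);
    println!("{}", ln_no_bias.forward(&xs)?);
    println!("{}", rms.forward(&xs)?);
    Ok(())
}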
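
A second sketch covers the checkpoint-driven path the examples use, going through the layer_norm/rms_norm builder functions and LayerNormConfig. The VarMap-backed VarBuilder and the variable names ("ln", "final_norm", "final_norm2") are assumptions made to keep the snippet self-contained; the real examples obtain their VarBuilder from safetensors or GGML weights instead.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{layer_norm, rms_norm, LayerNormConfig, VarBuilder, VarMap};

fn builder_variants() -> Result<()> {
    let dev = Device::Cpu;
    // Assumption: a fresh VarMap-backed builder, so the weights fall back to the
    // Init constants used in layer_norm (1.0 for "weight", 0.0 for "bias").
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // A plain f64 eps converts into LayerNormConfig { eps, remove_mean: true, affine: true }.
    let ln = layer_norm(4, 1e-5, vb.pp("ln"))?;
    // rms_norm is shorthand for remove_mean: false, affine: false.
    let final_norm = rms_norm(4, 1e-5, vb.pp("final_norm"))?;
    // The same RMS layer spelled out through the config struct.
    let cfg = LayerNormConfig { eps: 1e-5, remove_mean: false, affine: false };
    let final_norm2 = layer_norm(4, cfg, vb.pp("final_norm2"))?;

    let xs = Tensor::new(&[[[0.5f32, -1.0, 2.0, 0.0]]], &dev)?;
    println!("{}", ln.forward(&xs)?);
    println!("{}", final_norm.forward(&xs)?);
    println!("{}", final_norm2.forward(&xs)?);
    Ok(())
}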