Layer norm tweaks (#482)

* Add some options to make layer-norm more configurable.

* Add the rms-norm variant.

* Replace the per-model RmsNorm implementations with the shared candle-nn version.
Laurent Mazare authored 2023-08-17 10:07:13 +01:00, committed by GitHub
parent d99cac3ec3
commit d32e8199cd
7 changed files with 124 additions and 158 deletions
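
The model files below all make the same move: the hand-rolled RMS normalization is deleted and replaced by the shared candle-nn helper. A minimal usage sketch of that helper, assuming a VarMap-backed VarBuilder as in the training examples; the hidden size and the `model.norm` prefix are illustrative:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{rms_norm, VarBuilder, VarMap};

fn main() -> Result<()> {
    // Freshly initialized variables on CPU, just to exercise the helper.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);

    // One call replaces the hand-rolled RmsNorm structs removed below.
    let norm = rms_norm(1024, 1e-5, vb.pp("model.norm"))?;

    // The forward pass expects a (batch, seq_len, hidden) tensor.
    let xs = Tensor::randn(0f32, 1f32, (1, 4, 1024), &Device::Cpu)?;
    let ys = norm.forward(&xs)?;
    assert_eq!(ys.dims(), &[1, 4, 1024]);
    Ok(())
}
```

The same `forward` call serves both the classic layer norm and the RMS variant, which is what lets the per-model wrappers shrink to a tracing span plus a delegation.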

View File

@@ -152,35 +152,20 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 }
 
 struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
+    inner: candle_nn::LayerNorm,
     span: tracing::Span,
 }
 
 impl RmsNorm {
     fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps, span })
+        let inner = candle_nn::rms_norm(size, eps, vb)?;
+        Ok(Self { inner, span })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let in_dtype = x.dtype();
-        // This is a no-op if x's dtype is already f32.
-        let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        let x = x.to_dtype(in_dtype)?;
-        Ok(x)
+        self.inner.forward(x)
     }
 }

View File

@@ -1,6 +1,6 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, Embedding, Linear, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -94,32 +94,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
 }
 
-struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
-        Ok(Self { scale, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -262,14 +236,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -290,9 +264,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,
@@ -305,7 +279,7 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: RmsNorm,
+    ln_f: LayerNorm,
     lm_head: Linear,
     pub config: Config,
 }
@@ -325,7 +299,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
             .collect();

View File

@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::rc::Rc;
@@ -182,39 +182,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
     Ok(Embedding::new(embeddings, cfg.hidden_size))
 }
 
-struct RmsNorm {
-    scale: Tensor,
-}
-
-impl RmsNorm {
-    fn load(size: usize, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get(size, "weight")?;
-        Ok(Self::new(scale))
-    }
-
-    fn new(scale: Tensor) -> Self {
-        Self { scale }
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let in_dtype = x.dtype();
-        // This is a no-op if x's dtype is already f32.
-        let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let size = self.scale.shape().dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        let x = x.to_dtype(in_dtype)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     qkv_proj: TensorParallelColumnLinear,
     o_proj: TensorParallelRowLinear,
@@ -369,14 +336,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -397,9 +364,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
-        let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,

View File

@@ -14,8 +14,7 @@ const MAX_SEQ_LEN: usize = 4096;
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
 struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
+    inner: candle_nn::LayerNorm,
     span: tracing::Span,
 }
@@ -23,26 +22,13 @@ impl RmsNorm {
     fn new(scale: QTensor) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = scale.dequantize(&Device::Cpu)?;
-        Ok(Self {
-            scale,
-            eps: 1e-5,
-            span,
-        })
+        let inner = candle_nn::LayerNorm::rms_norm(scale, 1e-5);
+        Ok(Self { inner, span })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
+        self.inner.forward(x)
     }
 }
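
In this quantized example the scale already exists as a plain tensor after dequantization, so the new constructor form is used instead of going through a VarBuilder. A small sketch, with a made-up all-ones weight standing in for the dequantized QTensor:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::LayerNorm;

// Hypothetical helper: in the quantized example the weight comes from
// QTensor::dequantize rather than being created here.
fn rms_from_weight(weight: Tensor) -> LayerNorm {
    // No bias and no mean removal: the RMS-norm flavour of LayerNorm.
    LayerNorm::rms_norm(weight, 1e-5)
}

fn demo() -> Result<()> {
    let weight = Tensor::ones(64, DType::F32, &Device::Cpu)?;
    let norm = rms_from_weight(weight);
    let xs = Tensor::randn(0f32, 1f32, (1, 3, 64), &Device::Cpu)?;
    let _ys = norm.forward(&xs)?;
    Ok(())
}
```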

View File

@@ -30,17 +30,70 @@
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
 use candle::{DType, Result, Tensor};
 
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct LayerNormConfig {
+    pub eps: f64,
+    /// Whether to remove the mean or not, the default is true and when set to false, this turns
+    /// this layer into RmsNorm.
+    pub remove_mean: bool,
+    pub affine: bool,
+}
+
+impl Default for LayerNormConfig {
+    fn default() -> Self {
+        Self {
+            eps: 1e-5,
+            remove_mean: true,
+            affine: true,
+        }
+    }
+}
+
+impl From<f64> for LayerNormConfig {
+    fn from(eps: f64) -> Self {
+        Self {
+            eps,
+            remove_mean: true,
+            affine: true,
+        }
+    }
+}
+
 // This layer norm version handles both weight and bias so removes the mean.
 #[derive(Debug)]
 pub struct LayerNorm {
     weight: Tensor,
-    bias: Tensor,
+    bias: Option<Tensor>,
+    remove_mean: bool,
     eps: f64,
 }
 
 impl LayerNorm {
     pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
+        Self {
+            weight,
+            bias: Some(bias),
+            remove_mean: true,
+            eps,
+        }
+    }
+
+    pub fn new_no_bias(weight: Tensor, eps: f64) -> Self {
+        Self {
+            weight,
+            bias: None,
+            remove_mean: true,
+            eps,
+        }
+    }
+
+    pub fn rms_norm(weight: Tensor, eps: f64) -> Self {
+        Self {
+            weight,
+            bias: None,
+            remove_mean: false,
+            eps,
+        }
     }
 
     pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -51,20 +104,47 @@ impl LayerNorm {
         };
         let (_bsize, _seq_len, hidden_size) = x.dims3()?;
         let x = x.to_dtype(internal_dtype)?;
-        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
+        let x = if self.remove_mean {
+            let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+            x.broadcast_sub(&mean_x)?
+        } else {
+            x
+        };
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(x_dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
+        let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
     }
 }
 
-pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+pub fn layer_norm<C: Into<LayerNormConfig>>(
+    size: usize,
+    config: C,
+    vb: crate::VarBuilder,
+) -> Result<LayerNorm> {
+    let config = config.into();
     let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
-    let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
-    Ok(LayerNorm::new(weight, bias, eps))
+    let bias = if config.affine {
+        Some(vb.get_or_init(size, "bias", crate::Init::Const(0.))?)
+    } else {
+        None
+    };
+    Ok(LayerNorm {
+        weight,
+        bias,
+        remove_mean: config.remove_mean,
+        eps: config.eps,
+    })
+}
+
+pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+    let config = LayerNormConfig {
+        eps,
+        remove_mean: false,
+        affine: false,
    };
+    layer_norm(size, config, vb)
 }
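
Putting the new candle-nn surface together: the `From<f64>` impl keeps existing `layer_norm(size, eps, vb)` call sites compiling, while an explicit `LayerNormConfig` can drop the learned bias or the mean subtraction. A sketch, with illustrative variable-path prefixes:

```rust
use candle_nn::{layer_norm, LayerNorm, LayerNormConfig, VarBuilder};

// Illustrative helper; any VarBuilder works here.
fn build_norms(size: usize, vb: VarBuilder) -> candle::Result<(LayerNorm, LayerNorm, LayerNorm)> {
    // A bare f64 converts into the default config: remove_mean = true, affine = true.
    let classic = layer_norm(size, 1e-5, vb.pp("classic"))?;

    // Keep the mean subtraction but skip the learned bias.
    let no_bias_cfg = LayerNormConfig { eps: 1e-5, remove_mean: true, affine: false };
    let no_bias = layer_norm(size, no_bias_cfg, vb.pp("no_bias"))?;

    // remove_mean = false and affine = false is exactly what rms_norm() builds.
    let rms_cfg = LayerNormConfig { eps: 1e-5, remove_mean: false, affine: false };
    let rms = layer_norm(size, rms_cfg, vb.pp("rms"))?;

    Ok((classic, no_bias, rms))
}
```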

View File

@@ -17,7 +17,7 @@ pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
 pub use embedding::{embedding, Embedding};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
-pub use layer_norm::{layer_norm, LayerNorm};
+pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::{AdamW, ParamsAdamW, SGD};
 pub use var_builder::{VarBuilder, VarMap};

View File

@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -71,32 +71,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
     Ok(Embedding::new(embeddings, cfg.dim))
 }
 
-struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -239,14 +213,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -267,9 +241,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,
@@ -282,12 +256,12 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: RmsNorm,
+    ln_f: LayerNorm,
     lm_head: Linear,
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
         Self {
             wte,
             blocks,
@@ -311,7 +285,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
             .collect();