From 13401df4d141bf568a2c2056411d62060707e79b Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 18 Aug 2023 08:52:14 +0100
Subject: [PATCH] Add an abstract type for RmsNorm. (#499)

---
 README.md                                  |  5 +++--
 candle-examples/examples/llama/model.rs    |  2 +-
 candle-examples/examples/llama2-c/model.rs | 10 +++++-----
 .../examples/llama_multiprocess/main.rs    |  6 ++++--
 .../examples/llama_multiprocess/model.rs   | 10 +++++-----
 candle-nn/src/layer_norm.rs                | 22 ++++++++++++++++++--
 candle-nn/src/lib.rs                       |  2 +-
 candle-wasm-examples/llama2-c/src/model.rs | 12 ++++++------
 8 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/README.md b/README.md
index 088e78b3..99c1a1b5 100644
--- a/README.md
+++ b/README.md
@@ -70,11 +70,12 @@ And then head over to
 - Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
 - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
 - WASM support, run your models in a browser.
-- Model support out of the box.
+- Included models.
     - LLMs: Llama v1 and v2, Falcon, StarCoder.
-    - Whisper.
+    - Whisper (multi-lingual support).
     - Stable Diffusion.
 - Serverless (on CPU), small and fast deployments.
+- Quantization support using the llama.cpp quantized types.
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index e0bb70e7..13eb7390 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -152,7 +152,7 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 }
 
 struct RmsNorm {
-    inner: candle_nn::LayerNorm,
+    inner: candle_nn::RmsNorm,
     span: tracing::Span,
 }
 
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 75269665..aae9673a 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -1,6 +1,6 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -236,14 +236,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: LayerNorm,
+    rms_1: RmsNorm,
     attn: CausalSelfAttention,
-    rms_2: LayerNorm,
+    rms_2: RmsNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -279,7 +279,7 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: LayerNorm,
+    ln_f: RmsNorm,
     lm_head: Linear,
     pub config: Config,
 }
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index c637a99a..d6d0d14e 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -231,7 +231,7 @@ fn main() -> Result<()> {
                 "{} token: {} '{}'",
                 index + 1,
                 next_token,
-                tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+                tokenizer.decode(&[next_token], true).map_err(E::msg)?
             );
         }
     }
@@ -241,7 +241,9 @@ fn main() -> Result<()> {
             "{} tokens generated ({} token/s)\n----\n{}\n----",
             args.sample_len,
             args.sample_len as f64 / dt.as_secs_f64(),
-            tokenizer.decode(new_tokens, true).map_err(E::msg)?
+            tokenizer
+                .decode(new_tokens.as_slice(), true)
+                .map_err(E::msg)?
         );
     }
     Ok(())
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index ad5e4cd2..fa8f9abf 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::rc::Rc;
@@ -336,14 +336,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: LayerNorm,
+    rms_1: RmsNorm,
     attn: CausalSelfAttention,
-    rms_2: LayerNorm,
+    rms_2: RmsNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -408,7 +408,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
+        let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
             .map(|i| {
                 Block::load(
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index f9892a2c..17cdef3d 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -140,11 +140,29 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
     })
 }
 
-pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+/// RmsNorm is a specialized version of the LayerNorm module.
+#[derive(Debug)]
+pub struct RmsNorm(LayerNorm);
+
+impl RmsNorm {
+    pub fn new(weight: Tensor, eps: f64) -> Self {
+        Self(LayerNorm::rms_norm(weight, eps))
+    }
+
+    pub fn into_inner(self) -> LayerNorm {
+        self.0
+    }
+
+    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.0.forward(xs)
+    }
+}
+
+pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<RmsNorm> {
     let config = LayerNormConfig {
         eps,
         remove_mean: false,
         affine: false,
     };
-    layer_norm(size, config, vb)
+    Ok(RmsNorm(layer_norm(size, config, vb)?))
 }
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 05464ceb..c04e8ff4 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -17,7 +17,7 @@ pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
 pub use embedding::{embedding, Embedding};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
-pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig};
+pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::{AdamW, ParamsAdamW, SGD};
 pub use var_builder::{VarBuilder, VarMap};
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index d2b787ae..2c867793 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -213,14 +213,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: LayerNorm,
+    rms_1: RmsNorm,
     attn: CausalSelfAttention,
-    rms_2: LayerNorm,
+    rms_2: RmsNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -256,12 +256,12 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: LayerNorm,
+    ln_f: RmsNorm,
     lm_head: Linear,
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
         Self {
             wte,
             blocks,
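
Note: below is a minimal usage sketch (not part of the patch) for the RmsNorm type added to
candle-nn above. RmsNorm::new, RmsNorm::forward, and the rms_norm builder come from the diff;
the hidden size of 4, the input values, and the 1e-5 epsilon are made-up illustrative choices.
The newtype wraps a LayerNorm configured without mean subtraction and without a bias term, so
call sites can name RmsNorm explicitly while reusing the existing LayerNorm forward pass.

    use candle::{Device, Result, Tensor};
    use candle_nn::RmsNorm;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Per-dimension scale weights for a (hypothetical) hidden size of 4.
        let weight = Tensor::new(&[1f32, 1.0, 1.0, 1.0], &dev)?;
        let norm = RmsNorm::new(weight, 1e-5);
        // Two example activations of dimension 4; values are arbitrary.
        let xs = Tensor::new(&[[1f32, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &dev)?;
        // Each row is rescaled by the reciprocal of its root-mean-square (no mean removal).
        let ys = norm.forward(&xs)?;
        println!("{ys}");
        Ok(())
    }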