Add an abstract type for RmsNorm. (#499)
@@ -70,11 +70,12 @@ And then head over to
 - Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
 - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
 - WASM support, run your models in a browser.
-- Model support out of the box.
+- Included models.
     - LLMs: Llama v1 and v2, Falcon, StarCoder.
-    - Whisper.
+    - Whisper (multi-lingual support).
     - Stable Diffusion.
 - Serverless (on CPU), small and fast deployments.
+- Quantization support using the llama.cpp quantized types.

 <!--- ANCHOR_END: features --->
@@ -152,7 +152,7 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 }

 struct RmsNorm {
-    inner: candle_nn::LayerNorm,
+    inner: candle_nn::RmsNorm,
     span: tracing::Span,
 }
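The example keeps its own thin wrapper so each norm call is recorded under a tracing span; only the field type changes here. The wrapper's `forward` isn't part of this hunk — a minimal sketch of how such a wrapper typically delegates, assuming the `inner`/`span` fields above:

```rust
use candle::{Result, Tensor};

struct RmsNorm {
    inner: candle_nn::RmsNorm,
    span: tracing::Span,
}

impl RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Entering the span attributes the time spent below to this norm
        // whenever a tracing subscriber is installed; otherwise it is a no-op.
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}
```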
@@ -1,6 +1,6 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -236,14 +236,14 @@ impl Mlp {
 }

 struct Block {
-    rms_1: LayerNorm,
+    rms_1: RmsNorm,
     attn: CausalSelfAttention,
-    rms_2: LayerNorm,
+    rms_2: RmsNorm,
     mlp: Mlp,
 }

 impl Block {
-    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
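The two norms sit in the usual pre-norm residual layout: each sub-layer reads a normalized copy of the stream and writes back additively. A simplified sketch of the block's forward under that assumption (the real example also threads positions and a KV cache through the attention call, so the signatures here are illustrative):

```rust
impl Block {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Attention sub-layer: normalize, transform, add back the residual.
        let x = (self.attn.forward(&self.rms_1.forward(x)?)? + x)?;
        // MLP sub-layer: same pattern using the second norm.
        let ys = (self.mlp.forward(&self.rms_2.forward(&x)?)? + &x)?;
        Ok(ys)
    }
}
```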
@@ -279,7 +279,7 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: LayerNorm,
+    ln_f: RmsNorm,
     lm_head: Linear,
     pub config: Config,
 }
@@ -231,7 +231,7 @@ fn main() -> Result<()> {
                 "{} token: {} '{}'",
                 index + 1,
                 next_token,
-                tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+                tokenizer.decode(&[next_token], true).map_err(E::msg)?
             );
         }
     }
@@ -241,7 +241,9 @@ fn main() -> Result<()> {
             "{} tokens generated ({} token/s)\n----\n{}\n----",
             args.sample_len,
             args.sample_len as f64 / dt.as_secs_f64(),
-            tokenizer.decode(new_tokens, true).map_err(E::msg)?
+            tokenizer
+                .decode(new_tokens.as_slice(), true)
+                .map_err(E::msg)?
         );
     }
     Ok(())
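Both decode changes track the same upstream API shift: newer versions of the `tokenizers` crate take the ids as a borrowed `&[u32]` rather than an owned `Vec<u32>`, so callers keep ownership of their token buffer. A small sketch of the call shape (the helper name is illustrative):

```rust
use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

// Decode a buffer of token ids without giving up ownership of it.
fn decode_ids(tokenizer: &Tokenizer, ids: &[u32]) -> Result<String> {
    tokenizer.decode(ids, true).map_err(E::msg)
}
```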
@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::rc::Rc;
@@ -336,14 +336,14 @@ impl Mlp {
 }

 struct Block {
-    rms_1: LayerNorm,
+    rms_1: RmsNorm,
     attn: CausalSelfAttention,
-    rms_2: LayerNorm,
+    rms_2: RmsNorm,
     mlp: Mlp,
 }

 impl Block {
-    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -408,7 +408,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
+        let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
             .map(|i| {
                 Block::load(
@@ -140,11 +140,29 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
     })
 }

-pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+/// RmsNorm is a specialized version of the LayerNorm module.
+#[derive(Debug)]
+pub struct RmsNorm(LayerNorm);
+
+impl RmsNorm {
+    pub fn new(weight: Tensor, eps: f64) -> Self {
+        Self(LayerNorm::rms_norm(weight, eps))
+    }
+
+    pub fn into_inner(self) -> LayerNorm {
+        self.0
+    }
+
+    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.0.forward(xs)
+    }
+}
+
+pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<RmsNorm> {
     let config = LayerNormConfig {
         eps,
         remove_mean: false,
         affine: false,
     };
-    layer_norm(size, config, vb)
+    Ok(RmsNorm(layer_norm(size, config, vb)?))
 }
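With the new type, callers obtain an `RmsNorm` either by wrapping an existing scale tensor or by loading one through a `VarBuilder`, as the model files above now do. A minimal usage sketch (the sizes, eps, and `model.norm` prefix are illustrative, and the `VarMap`-backed builder is just one way to get a `VarBuilder`):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{rms_norm, RmsNorm, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Direct construction from a scale tensor.
    let weight = Tensor::ones(64, DType::F32, &dev)?;
    let norm = RmsNorm::new(weight, 1e-5);

    // Construction through a VarBuilder, matching the model-loading code.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let norm_vb = rms_norm(64, 1e-5, vb.pp("model.norm"))?;

    let xs = Tensor::randn(0f32, 1f32, (2, 3, 64), &dev)?;
    println!("{:?}", norm.forward(&xs)?.shape());
    println!("{:?}", norm_vb.forward(&xs)?.shape());
    Ok(())
}
```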
@@ -17,7 +17,7 @@ pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
 pub use embedding::{embedding, Embedding};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
-pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig};
+pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::{AdamW, ParamsAdamW, SGD};
 pub use var_builder::{VarBuilder, VarMap};
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -213,14 +213,14 @@ impl Mlp {
 }

 struct Block {
-    rms_1: LayerNorm,
+    rms_1: RmsNorm,
     attn: CausalSelfAttention,
-    rms_2: LayerNorm,
+    rms_2: RmsNorm,
     mlp: Mlp,
 }

 impl Block {
-    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -256,12 +256,12 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: LayerNorm,
+    ln_f: RmsNorm,
     lm_head: Linear,
 }

 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
         Self {
             wte,
             blocks,