Add an abstract type for RmsNorm. (#499)

This commit is contained in:
Laurent Mazare
2023-08-18 08:52:14 +01:00
committed by GitHub
parent a22b1bed7b
commit 13401df4d1
8 changed files with 45 additions and 24 deletions

View File

@ -152,7 +152,7 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
}
struct RmsNorm {
inner: candle_nn::LayerNorm,
inner: candle_nn::RmsNorm,
span: tracing::Span,
}

View File

@ -1,6 +1,6 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
use candle_nn::{embedding, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -236,14 +236,14 @@ impl Mlp {
}
struct Block {
rms_1: LayerNorm,
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: LayerNorm,
rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@ -279,7 +279,7 @@ impl Block {
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: LayerNorm,
ln_f: RmsNorm,
lm_head: Linear,
pub config: Config,
}

View File

@ -231,7 +231,7 @@ fn main() -> Result<()> {
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
tokenizer.decode(&[next_token], true).map_err(E::msg)?
);
}
}
@ -241,7 +241,9 @@ fn main() -> Result<()> {
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer.decode(new_tokens, true).map_err(E::msg)?
tokenizer
.decode(new_tokens.as_slice(), true)
.map_err(E::msg)?
);
}
Ok(())

View File

@ -1,6 +1,6 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use std::rc::Rc;
@ -336,14 +336,14 @@ impl Mlp {
}
struct Block {
rms_1: LayerNorm,
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: LayerNorm,
rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@ -408,7 +408,7 @@ impl Llama {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layer)
.map(|i| {
Block::load(