Add an abstract type for RmsNorm. (#499)

This commit is contained in:
Laurent Mazare
2023-08-18 08:52:14 +01:00
committed by GitHub
parent a22b1bed7b
commit 13401df4d1
8 changed files with 45 additions and 24 deletions

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -213,14 +213,14 @@ impl Mlp {
}
struct Block {
rms_1: LayerNorm,
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: LayerNorm,
rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@ -256,12 +256,12 @@ impl Block {
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: LayerNorm,
ln_f: RmsNorm,
lm_head: Linear,
}
impl Llama {
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
Self {
wte,
blocks,