mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add an abstract type for RmsNorm. (#499)
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::linear_no_bias as linear;
|
||||
use candle_nn::{embedding, rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
|
||||
use candle_nn::{embedding, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
@ -236,14 +236,14 @@ impl Mlp {
|
||||
}
|
||||
|
||||
struct Block {
|
||||
rms_1: LayerNorm,
|
||||
rms_1: RmsNorm,
|
||||
attn: CausalSelfAttention,
|
||||
rms_2: LayerNorm,
|
||||
rms_2: RmsNorm,
|
||||
mlp: Mlp,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
|
||||
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
|
||||
Self {
|
||||
rms_1,
|
||||
attn,
|
||||
@ -279,7 +279,7 @@ impl Block {
|
||||
pub struct Llama {
|
||||
wte: Embedding,
|
||||
blocks: Vec<Block>,
|
||||
ln_f: LayerNorm,
|
||||
ln_f: RmsNorm,
|
||||
lm_head: Linear,
|
||||
pub config: Config,
|
||||
}
|
||||
|
Reference in New Issue
Block a user