Add an abstract type for RmsNorm. (#499)
@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::rc::Rc;
@@ -336,14 +336,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: LayerNorm,
+    rms_1: RmsNorm,
     attn: CausalSelfAttention,
-    rms_2: LayerNorm,
+    rms_2: RmsNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
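
The field and signature swaps above are the whole migration for `Block`: like `LayerNorm`, the new `RmsNorm` exposes `forward` through candle's `Module` trait, so the block's forward path does not change. A minimal sketch of that, assuming the current `Module` trait; `NormBlock` is a hypothetical cut-down stand-in for the real `Block` (which also carries `attn` and `mlp`):

use candle::{Module, Result, Tensor};
use candle_nn::RmsNorm;

// Hypothetical cut-down block: only the normalization field, to show that
// swapping LayerNorm for RmsNorm leaves the call site untouched.
struct NormBlock {
    rms_1: RmsNorm,
}

impl NormBlock {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Same call as with the old LayerNorm field: both types implement
        // the Module trait, so `forward` has the same shape-preserving API.
        self.rms_1.forward(x)
    }
}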
@@ -408,7 +408,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
+        let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
             .map(|i| {
                 Block::load(
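
At the load site, the example no longer builds its own norm via `RmsNorm::load` but asks candle-nn for one with `rms_norm(size, eps, vb)`, as the last hunk shows. A self-contained sketch of that API, using a fresh `VarMap` instead of checkpoint weights; the sizes and the `model.norm` path here are illustrative, not taken from the commit:

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{rms_norm, RmsNorm, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Assumption: back the VarBuilder with a VarMap so `rms_norm` creates a
    // fresh (all-ones) weight rather than loading one from disk.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // Same call shape as the diff: rms_norm(hidden_size, eps, vb.pp(path)).
    let norm: RmsNorm = rms_norm(16, 1e-5, vb.pp("model.norm"))?;

    // Apply it through the Module trait:
    // y = x / sqrt(mean(x^2, last dim) + eps) * weight
    let xs = Tensor::randn(0f32, 1f32, (2, 4, 16), &device)?;
    let ys = norm.forward(&xs)?;
    println!("{:?}", ys.shape()); // same shape as the input: (2, 4, 16)
    Ok(())
}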