mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Use a common with_tracing::RmsNorm in a few models. (#1871)
* Add RmsNorm with tracing. * Use with_tracing::RmsNorm in some models.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use super::with_tracing::{linear_no_bias as linear, Linear};
|
||||
use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use std::collections::HashMap;
|
||||
@ -133,25 +133,6 @@ impl Cache {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::RmsNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let inner = candle_nn::rms_norm(size, eps, vb)?;
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CausalSelfAttention {
|
||||
q_proj: Linear,
|
||||
@ -377,8 +358,8 @@ impl Block {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "block");
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
|
||||
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
||||
let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
let rms_2 = RmsNorm::load(
|
||||
let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
let rms_2 = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
@ -417,7 +398,7 @@ impl Llama {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
|
||||
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
|
||||
.collect();
|
||||
|
Reference in New Issue
Block a user