diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index a091d3eb..400351f3 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,4 +1,4 @@
-use super::with_tracing::{linear_no_bias as linear, Linear};
+use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use std::collections::HashMap;
@@ -133,25 +133,6 @@ impl Cache {
     }
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct CausalSelfAttention {
     q_proj: Linear,
@@ -377,8 +358,8 @@ impl Block {
         let span = tracing::span!(tracing::Level::TRACE, "block");
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
-        let rms_2 = RmsNorm::load(
+        let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let rms_2 = RmsNorm::new(
             cfg.hidden_size,
             cfg.rms_norm_eps,
             vb.pp("post_attention_layernorm"),
@@ -417,7 +398,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
+        let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
             .collect();
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 2809ae0a..be84f824 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -1,4 +1,4 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
 /// Mistral LLM, https://github.com/mistralai/mistral-src
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
@@ -77,27 +77,6 @@ impl Config {
     }
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for RmsNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs
index ede74d3f..f69c68e3 100644
--- a/candle-transformers/src/models/mixtral.rs
+++ b/candle-transformers/src/models/mixtral.rs
@@ -1,4 +1,4 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
 /// Mixtral Model
 /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
 /// https://mistral.ai/news/mixtral-of-experts/
@@ -48,27 +48,6 @@ impl Config {
     }
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for RmsNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 26431b7d..9a12eba5 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -1,4 +1,4 @@
-use crate::models::with_tracing::{linear, linear_no_bias, Linear};
+use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
@@ -21,27 +21,6 @@ pub struct Config {
     pub hidden_act: Activation,
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for RmsNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index 2ffec724..1c34bfa2 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -167,3 +167,24 @@ pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
     let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
     Ok(LayerNorm { inner, span })
 }
+
+#[derive(Debug, Clone)]
+pub struct RmsNorm {
+    inner: candle_nn::RmsNorm,
+    span: tracing::Span,
+}
+
+impl RmsNorm {
+    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+        let inner = candle_nn::rms_norm(size, eps, vb)?;
+        Ok(Self { inner, span })
+    }
+}
+
+impl Module for RmsNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs
index 14b6feeb..99d9de1b 100644
--- a/candle-transformers/src/models/yi.rs
+++ b/candle-transformers/src/models/yi.rs
@@ -1,5 +1,5 @@
 /// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
@@ -50,27 +50,6 @@ impl Config {
     }
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for RmsNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
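
The patch removes the per-model private RmsNorm wrappers and keeps a single public one in with_tracing.rs; callers only swap the old constructor (RmsNorm::load in llama.rs, RmsNorm::new elsewhere) for with_tracing::RmsNorm::new and keep forwarding through Module. The snippet below is a minimal usage sketch of the shared wrapper, not part of the patch: the VarMap-backed VarBuilder, the "model.norm" prefix, and the tensor shapes are illustrative assumptions, chosen to show that `new` pulls a `weight` vector from the given prefix and that `forward` runs inside the "rms-norm" tracing span.

// Illustrative only (not from the patch): exercise the shared with_tracing::RmsNorm end to end.
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};
use candle_transformers::models::with_tracing::RmsNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // A VarMap-backed builder creates missing tensors on demand, so the
    // `weight` vector under "model.norm" gets a default initialization here;
    // a real model would load it from checkpoint weights instead.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    let hidden_size = 8;
    let norm = RmsNorm::new(hidden_size, 1e-6, vb.pp("model.norm"))?;

    // Forward goes through the Module impl, which enters the "rms-norm" span.
    let xs = Tensor::randn(0f32, 1f32, (1, 4, hidden_size), &dev)?;
    let ys = norm.forward(&xs)?;
    assert_eq!(ys.dims(), xs.dims());
    Ok(())
}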