Use a common with_tracing::RmsNorm in a few models. (#1871)

* Add RmsNorm with tracing.

* Use with_tracing::RmsNorm in some models.
commit 90fc82211f
parent 6a966cf9e0
Author: Jani Monoses
Date:   2024-03-18 22:40:06 +02:00
Committed by: GitHub
6 changed files with 29 additions and 111 deletions
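The spans created by the with_tracing wrappers only surface when a tracing subscriber is installed. As a hedged illustration (this snippet is not part of the commit; it follows the tracing-chrome / tracing-subscriber setup pattern used in the candle examples), a binary can record the "rms-norm" and "block" spans like this:

// Hypothetical setup snippet, not part of this diff: install a Chrome-trace
// subscriber so spans such as "rms-norm" and "block" end up in a
// trace-*.json file viewable in chrome://tracing or Perfetto.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    // ... load and run a model here; keep `_guard` alive until the end so the
    // trace file is flushed when it is dropped.
}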

View File

@@ -1,4 +1,4 @@
-use super::with_tracing::{linear_no_bias as linear, Linear};
+use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use std::collections::HashMap;
@@ -133,25 +133,6 @@ impl Cache {
     }
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct CausalSelfAttention {
     q_proj: Linear,
@@ -377,8 +358,8 @@ impl Block {
         let span = tracing::span!(tracing::Level::TRACE, "block");
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
-        let rms_2 = RmsNorm::load(
+        let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let rms_2 = RmsNorm::new(
             cfg.hidden_size,
             cfg.rms_norm_eps,
             vb.pp("post_attention_layernorm"),
@@ -417,7 +398,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
+        let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
             .collect();

View File

@@ -1,4 +1,4 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
 /// Mistral LLM, https://github.com/mistralai/mistral-src
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
@@ -77,27 +77,6 @@ impl Config {
     }
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for RmsNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,

View File

@@ -1,4 +1,4 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
 /// Mixtral Model
 /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
 /// https://mistral.ai/news/mixtral-of-experts/
@@ -48,27 +48,6 @@ impl Config {
     }
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for RmsNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,

View File

@@ -1,4 +1,4 @@
-use crate::models::with_tracing::{linear, linear_no_bias, Linear};
+use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
@@ -21,27 +21,6 @@ pub struct Config {
     pub hidden_act: Activation,
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for RmsNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,

View File

@@ -167,3 +167,24 @@ pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
     let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
     Ok(LayerNorm { inner, span })
 }
+
+#[derive(Debug, Clone)]
+pub struct RmsNorm {
+    inner: candle_nn::RmsNorm,
+    span: tracing::Span,
+}
+
+impl RmsNorm {
+    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+        let inner = candle_nn::rms_norm(size, eps, vb)?;
+        Ok(Self { inner, span })
+    }
+}
+
+impl Module for RmsNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
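A minimal usage sketch of the new shared wrapper (hypothetical snippet, not part of the diff; it mirrors the call sites updated above): construct it from the hidden size, epsilon, and a VarBuilder prefix, then call forward, which enters the "rms-norm" span before delegating to candle_nn::RmsNorm.

// Hypothetical example, written as if inside the same models module as the
// files in this commit; only RmsNorm::new and Module::forward from the hunk
// above are used.
use crate::models::with_tracing::RmsNorm;
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};

fn apply_input_layernorm(hidden_size: usize, eps: f64, vb: VarBuilder, xs: &Tensor) -> Result<Tensor> {
    // Loads the norm weight from the "input_layernorm" prefix of `vb`.
    let norm = RmsNorm::new(hidden_size, eps, vb.pp("input_layernorm"))?;
    // Records a "rms-norm" trace span around candle_nn's RmsNorm forward pass.
    norm.forward(xs)
}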

View File

@@ -1,5 +1,5 @@
 /// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
@@ -50,27 +50,6 @@ impl Config {
     }
 }
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for RmsNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,