Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
Tracing for the phi model (#936)

* Add some tracing bits to mixformers.
* Add the missing file.
* Add the conv2d layer to with-tracing.
* Improve the tracing usage.
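For context, the `with_tracing` helpers this diff switches to wrap the plain `candle_nn` layers with a `tracing` span that is entered on every forward pass. A minimal sketch of the pattern, assuming the names from the diff's import (`linear`, `Linear`); the body is a reconstruction of the idea, not the exact module:

    use candle::{Result, Tensor};
    use candle_nn::{Module, VarBuilder};

    // Sketch of a tracing-aware Linear: same behavior as
    // candle_nn::Linear, plus a span recording each forward call.
    #[derive(Debug)]
    pub struct Linear {
        inner: candle_nn::Linear,
        span: tracing::Span,
    }

    impl Module for Linear {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            // Entering the span here is what makes each call to this
            // layer show up as a "linear" event in the trace output.
            let _enter = self.span.enter();
            self.inner.forward(xs)
        }
    }

    // Drop-in replacement for candle_nn::linear.
    pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
        let inner = candle_nn::linear(in_dim, out_dim, vb)?;
        let span = tracing::span!(tracing::Level::TRACE, "linear");
        Ok(Linear { inner, span })
    }

With that wrapper in hand, the diff below is mostly mechanical: swap the `candle_nn` types for the traced ones, and give `MHA`, `ParallelBlock`, and the top-level model their own spans so the trace nests as mixformer > block > mha > linear.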
candle-transformers/src/models/mixformer.rs

@@ -1,3 +1,4 @@
+use crate::models::with_tracing::{linear, Embedding as E, Linear};
 /// MixFormer model.
 /// https://huggingface.co/microsoft/phi-1_5
 /// https://arxiv.org/abs/2309.05463
@@ -58,12 +59,12 @@ impl Config {
 
 #[derive(Debug)]
 struct Embedding {
-    wte: candle_nn::Embedding,
+    wte: E,
 }
 
 impl Embedding {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let wte = candle_nn::embedding(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
+        let wte = E::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
         Ok(Self { wte })
     }
 }
@@ -143,16 +144,16 @@ impl RotaryEmbedding {
 #[derive(Debug)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
-    fc1: candle_nn::Linear,
-    fc2: candle_nn::Linear,
+    fc1: Linear,
+    fc2: Linear,
     act: Activation,
 }
 
 impl MLP {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
-        let fc1 = candle_nn::linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
-        let fc2 = candle_nn::linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
+        let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
+        let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
         Ok(Self {
             fc1,
             fc2,
@@ -170,13 +171,13 @@ impl Module for MLP {
 #[derive(Debug)]
 struct CausalLMHead {
     ln: candle_nn::LayerNorm,
-    linear: candle_nn::Linear,
+    linear: Linear,
 }
 
 impl CausalLMHead {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
-        let linear = candle_nn::linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
+        let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
         Ok(Self { ln, linear })
     }
 }
@@ -192,20 +193,21 @@ impl Module for CausalLMHead {
 #[derive(Debug)]
 #[allow(clippy::upper_case_acronyms)]
 struct MHA {
-    wqkv: candle_nn::Linear,
-    out_proj: candle_nn::Linear,
+    wqkv: Linear,
+    out_proj: Linear,
     rotary_emb: RotaryEmbedding,
     kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
     softmax_scale: f64,
+    span: tracing::Span,
 }
 
 impl MHA {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let head_dim = cfg.n_embd / cfg.n_head;
         let op_size = cfg.n_embd;
-        let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
-        let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
+        let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
+        let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
         let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
         let softmax_scale = 1f64 / (head_dim as f64).sqrt();
         Ok(Self {
@@ -215,10 +217,12 @@ impl MHA {
             kv_cache: None,
             rotary_emb,
             softmax_scale,
+            span: tracing::span!(tracing::Level::TRACE, "mha"),
         })
     }
 
     fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_size, seq_len, _n_embd) = xs.dims3()?;
         let qkv = self
             .wqkv
@@ -267,6 +271,7 @@ struct ParallelBlock {
     ln: candle_nn::LayerNorm,
     mixer: MHA,
     mlp: MLP,
+    span: tracing::Span,
 }
 
 impl ParallelBlock {
@@ -274,10 +279,16 @@ impl ParallelBlock {
         let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
         let mixer = MHA::new(cfg, vb.pp("mixer"))?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
-        Ok(Self { ln, mixer, mlp })
+        Ok(Self {
+            ln,
+            mixer,
+            mlp,
+            span: tracing::span!(tracing::Level::TRACE, "block"),
+        })
     }
 
     fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let xs = xs.apply(&self.ln)?;
         let attn_outputs = self.mixer.forward(&xs)?;
@@ -291,6 +302,7 @@ pub struct MixFormerSequentialForCausalLM {
     embedding: Embedding,
     blocks: Vec<ParallelBlock>,
     head: CausalLMHead,
+    span: tracing::Span,
 }
 
 impl MixFormerSequentialForCausalLM {
@@ -307,10 +319,12 @@ impl MixFormerSequentialForCausalLM {
             embedding,
             blocks,
             head,
+            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
        })
    }
 
     pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (_b_size, seq_len) = xs.dims2()?;
         let mut xs = xs.apply(&self.embedding)?;
         for block in self.blocks.iter_mut() {
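The spans above emit nothing unless a subscriber is installed by the caller. A minimal sketch of how one might collect the mixformer > block > mha hierarchy into a Chrome trace, assuming the `tracing-chrome` and `tracing-subscriber` crates (the candle examples follow the same pattern behind a `--tracing` flag):

    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    fn main() {
        // Keep the guard alive until the end of the run so the
        // trace-<timestamp>.json file is flushed to disk.
        let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();

        // ... build the model and call model.forward(&input) here ...
        // Afterwards, open the JSON file in chrome://tracing or
        // https://ui.perfetto.dev to inspect per-layer timings.
    }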