From df6667ba88b0185e7943d58d32507d40ca275824 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 3 Aug 2023 13:52:22 +0100
Subject: [PATCH] Add some tracing to llama. (#318)

---
 candle-examples/examples/llama/main.rs  | 14 ++++++++
 candle-examples/examples/llama/model.rs | 43 ++++++++++++++++++++++---
 2 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index f3cf17bc..d0a55be1 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -111,6 +111,10 @@ struct Args {
     #[arg(long)]
     use_f32: bool,
 
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
     #[arg(long)]
     model_id: Option<String>,
 
@@ -123,8 +127,18 @@ struct Args {
 
 fn main() -> Result<()> {
     use tokenizers::Tokenizer;
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
 
     let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
 
     let device = candle_examples::device(args.cpu)?;
     let config = if args.v1 {
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index ae27afc1..f5ac587e 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{Embedding, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -47,6 +47,21 @@ impl Config {
     }
 }
 
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
 #[derive(Clone)]
 pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@@ -106,8 +121,9 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 }
 
 fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+    Ok(Linear { inner, span })
 }
 
 fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
@@ -118,15 +134,18 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 struct RmsNorm {
     scale: Tensor,
     eps: f64,
+    span: tracing::Span,
 }
 
 impl RmsNorm {
     fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps })
+        Ok(Self { scale, eps, span })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
@@ -155,6 +174,8 @@ struct CausalSelfAttention {
     head_dim: usize,
     cache: Cache,
     use_flash_attn: bool,
+    span: tracing::Span,
+    span_rot: tracing::Span,
 }
 
 #[cfg(feature = "flash-attn")]
@@ -175,6 +196,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
 
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let _enter = self.span_rot.enter();
         let (b_sz, _, seq_len, n_embd) = x.dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
@@ -188,6 +210,7 @@ impl CausalSelfAttention {
     }
 
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
         let k = self.k_proj.forward(x)?;
@@ -269,6 +292,8 @@ impl CausalSelfAttention {
     }
 
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "attn");
+        let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
         let size_in = cfg.hidden_size;
         let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
         let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
@@ -286,6 +311,8 @@ impl CausalSelfAttention {
             head_dim: cfg.hidden_size / cfg.n_head,
             cache: cache.clone(),
             use_flash_attn: cfg.use_flash_attn,
+            span,
+            span_rot,
         })
     }
 }
@@ -301,15 +328,18 @@ struct Mlp {
     c_fc1: Linear,
     c_fc2: Linear,
     c_proj: Linear,
+    span: tracing::Span,
 }
 
 impl Mlp {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
     }
 
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "mlp");
         let h_size = cfg.hidden_size;
         let i_size = cfg.intermediate_size;
         let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
@@ -319,6 +349,7 @@ impl Mlp {
             c_fc1,
             c_fc2,
             c_proj,
+            span,
         })
     }
 }
@@ -328,10 +359,12 @@ struct Block {
     attn: CausalSelfAttention,
     rms_2: RmsNorm,
     mlp: Mlp,
+    span: tracing::Span,
 }
 
 impl Block {
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = x;
         let x = self.rms_1.forward(x)?;
         let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
@@ -341,6 +374,7 @@ impl Block {
     }
 
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "block");
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
         let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@@ -354,6 +388,7 @@ impl Block {
             attn,
             rms_2,
             mlp,
+            span,
         })
     }
 }
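
Note: the sketch below shows the tracing pattern this patch introduces, reduced to a
standalone program. It is a minimal illustration, assuming the tracing, tracing-chrome,
and tracing-subscriber crates as Cargo dependencies; heavy_work is a hypothetical
stand-in for the traced model code and is not part of the patch itself.

// Minimal sketch (not part of the patch above): the same Chrome-trace setup as in
// main.rs, applied to a stand-in workload instead of the llama model.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

// Hypothetical stand-in for an instrumented layer such as the wrapped `Linear::forward`.
fn heavy_work() {
    let span = tracing::span!(tracing::Level::TRACE, "heavy-work");
    let _enter = span.enter();
    std::thread::sleep(std::time::Duration::from_millis(10));
}

fn main() {
    // Writes a trace-timestamp.json file; keep the guard alive so the trace is
    // flushed when it is dropped at the end of main.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    for _ in 0..5 {
        heavy_work();
    }
    // The generated trace-timestamp.json can be opened in chrome://tracing or
    // Perfetto to inspect the recorded spans.
}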