diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 4b290cd8..25c7db98 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -70,7 +70,7 @@ impl TextGeneration { } let dt = start_gen.elapsed(); println!( - "{sample_len} tokens generated ({:.3} token/s)", + "\n{sample_len} tokens generated ({:.2} token/s)", sample_len as f64 / dt.as_secs_f64(), ); Ok(()) @@ -84,6 +84,10 @@ struct Args { #[arg(long)] cpu: bool, + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + #[arg(long)] prompt: String, @@ -114,8 +118,19 @@ struct Args { } fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + let start = std::time::Instant::now(); let api = Api::new()?; let repo = api.repo(Repo::with_revision( diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index 028c3567..61eaea54 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -1,3 +1,4 @@ +use crate::models::with_tracing::{linear, Embedding as E, Linear}; /// MixFormer model. /// https://huggingface.co/microsoft/phi-1_5 /// https://arxiv.org/abs/2309.05463 @@ -58,12 +59,12 @@ impl Config { #[derive(Debug)] struct Embedding { - wte: candle_nn::Embedding, + wte: E, } impl Embedding { fn new(cfg: &Config, vb: VarBuilder) -> Result { - let wte = candle_nn::embedding(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?; + let wte = E::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?; Ok(Self { wte }) } } @@ -143,16 +144,16 @@ impl RotaryEmbedding { #[derive(Debug)] #[allow(clippy::upper_case_acronyms)] struct MLP { - fc1: candle_nn::Linear, - fc2: candle_nn::Linear, + fc1: Linear, + fc2: Linear, act: Activation, } impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd); - let fc1 = candle_nn::linear(cfg.n_embd, n_inner, vb.pp("fc1"))?; - let fc2 = candle_nn::linear(n_inner, cfg.n_embd, vb.pp("fc2"))?; + let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?; + let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?; Ok(Self { fc1, fc2, @@ -170,13 +171,13 @@ impl Module for MLP { #[derive(Debug)] struct CausalLMHead { ln: candle_nn::LayerNorm, - linear: candle_nn::Linear, + linear: Linear, } impl CausalLMHead { fn new(cfg: &Config, vb: VarBuilder) -> Result { let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?; - let linear = candle_nn::linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?; + let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?; Ok(Self { ln, linear }) } } @@ -192,20 +193,21 @@ impl Module for CausalLMHead { #[derive(Debug)] #[allow(clippy::upper_case_acronyms)] struct MHA { - wqkv: candle_nn::Linear, - out_proj: candle_nn::Linear, + wqkv: Linear, + out_proj: Linear, rotary_emb: RotaryEmbedding, kv_cache: Option<(Tensor, Tensor)>, head_dim: usize, softmax_scale: f64, + span: tracing::Span, } impl MHA { fn new(cfg: &Config, vb: VarBuilder) -> Result { let head_dim = cfg.n_embd / cfg.n_head; let op_size = cfg.n_embd; - let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?; - let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?; + let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?; + let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?; let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?; let softmax_scale = 1f64 / (head_dim as f64).sqrt(); Ok(Self { @@ -215,10 +217,12 @@ impl MHA { kv_cache: None, rotary_emb, softmax_scale, + span: tracing::span!(tracing::Level::TRACE, "mha"), }) } fn forward(&mut self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let (b_size, seq_len, _n_embd) = xs.dims3()?; let qkv = self .wqkv @@ -267,6 +271,7 @@ struct ParallelBlock { ln: candle_nn::LayerNorm, mixer: MHA, mlp: MLP, + span: tracing::Span, } impl ParallelBlock { @@ -274,10 +279,16 @@ impl ParallelBlock { let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?; let mixer = MHA::new(cfg, vb.pp("mixer"))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?; - Ok(Self { ln, mixer, mlp }) + Ok(Self { + ln, + mixer, + mlp, + span: tracing::span!(tracing::Level::TRACE, "block"), + }) } fn forward(&mut self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let residual = xs; let xs = xs.apply(&self.ln)?; let attn_outputs = self.mixer.forward(&xs)?; @@ -291,6 +302,7 @@ pub struct MixFormerSequentialForCausalLM { embedding: Embedding, blocks: Vec, head: CausalLMHead, + span: tracing::Span, } impl MixFormerSequentialForCausalLM { @@ -307,10 +319,12 @@ impl MixFormerSequentialForCausalLM { embedding, blocks, head, + span: tracing::span!(tracing::Level::TRACE, "mixformer"), }) } pub fn forward(&mut self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let (_b_size, seq_len) = xs.dims2()?; let mut xs = xs.apply(&self.embedding)?; for block in self.blocks.iter_mut() { diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 991ee201..0fbcaa07 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -11,4 +11,5 @@ pub mod segment_anything; pub mod stable_diffusion; pub mod t5; pub mod whisper; +pub mod with_tracing; pub mod wuerstchen; diff --git a/candle-transformers/src/models/stable_diffusion/resnet.rs b/candle-transformers/src/models/stable_diffusion/resnet.rs index 0d818115..5df04a8b 100644 --- a/candle-transformers/src/models/stable_diffusion/resnet.rs +++ b/candle-transformers/src/models/stable_diffusion/resnet.rs @@ -4,7 +4,7 @@ //! //! Denoising Diffusion Implicit Models, K. He and al, 2015. //! https://arxiv.org/abs/1512.03385 -use super::utils::{conv2d, Conv2d}; +use crate::models::with_tracing::{conv2d, Conv2d}; use candle::{Result, Tensor, D}; use candle_nn as nn; use candle_nn::Module; diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs index a3ed136e..f23bd425 100644 --- a/candle-transformers/src/models/stable_diffusion/unet_2d.rs +++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs @@ -4,7 +4,7 @@ //! timestep and return a denoised version of the input. use super::embeddings::{TimestepEmbedding, Timesteps}; use super::unet_2d_blocks::*; -use super::utils::{conv2d, Conv2d}; +use crate::models::with_tracing::{conv2d, Conv2d}; use candle::{Result, Tensor}; use candle_nn as nn; use candle_nn::Module; diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs index 29510cef..18448427 100644 --- a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs +++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs @@ -4,7 +4,7 @@ use super::attention::{ AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, }; use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; -use super::utils::{conv2d, Conv2d}; +use crate::models::with_tracing::{conv2d, Conv2d}; use candle::{Module, Result, Tensor, D}; use candle_nn as nn; diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs index c62f17af..0c95cfef 100644 --- a/candle-transformers/src/models/stable_diffusion/utils.rs +++ b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -1,5 +1,4 @@ use candle::{Device, Result, Tensor}; -use candle_nn::Module; pub fn linspace(start: f64, stop: f64, steps: usize) -> Result { if steps < 1 { @@ -11,29 +10,3 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result { .collect::>(); Tensor::from_vec(vs, steps, &Device::Cpu) } - -// Wrap the conv2d op to provide some tracing. -#[derive(Debug)] -pub struct Conv2d { - inner: candle_nn::Conv2d, - span: tracing::Span, -} - -impl Conv2d { - pub fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -pub fn conv2d( - in_channels: usize, - out_channels: usize, - kernel_size: usize, - cfg: candle_nn::Conv2dConfig, - vs: candle_nn::VarBuilder, -) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "conv2d"); - let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?; - Ok(Conv2d { inner, span }) -} diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 539ae89b..c5d5724a 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -1,57 +1,12 @@ // T5 Text Model // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py +use crate::models::with_tracing::{linear_no_bias, Embedding, Linear}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use serde::Deserialize; use std::sync::Arc; -#[derive(Debug)] -struct Embedding { - inner: candle_nn::Embedding, - span: tracing::Span, -} - -impl Embedding { - fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result { - let inner = candle_nn::embedding(d1, d2, vb)?; - let span = tracing::span!(tracing::Level::TRACE, "embedding"); - Ok(Self { inner, span }) - } - - fn embeddings(&self) -> &Tensor { - self.inner.embeddings() - } -} - -impl Module for Embedding { - fn forward(&self, xs: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(xs) - } -} - -#[derive(Debug)] -struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Linear { - fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result { - let inner = candle_nn::linear_no_bias(d1, d2, vb)?; - let span = tracing::span!(tracing::Level::TRACE, "linear"); - Ok(Self { inner, span }) - } -} - -impl Module for Linear { - fn forward(&self, xs: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(xs) - } -} - fn default_relative_attention_max_distance() -> usize { 128 } @@ -205,8 +160,8 @@ struct T5DenseActDense { impl T5DenseActDense { fn load(vb: VarBuilder, cfg: &Config) -> Result { - let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?; - let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?; + let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; Ok(Self { wi, wo, @@ -237,9 +192,9 @@ struct T5DenseGatedActDense { impl T5DenseGatedActDense { fn load(vb: VarBuilder, cfg: &Config) -> Result { - let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?; - let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?; - let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?; + let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?; + let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; Ok(Self { wi_0, wi_1, @@ -334,10 +289,10 @@ impl T5Attention { cfg: &Config, ) -> Result { let inner_dim = cfg.num_heads * cfg.d_kv; - let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?; - let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?; - let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?; - let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?; + let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?; + let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?; + let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?; + let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?; let relative_attention_bias = if has_relative_attention_bias { let emb = Embedding::new( cfg.relative_attention_num_buckets, @@ -772,7 +727,11 @@ impl T5ForConditionalGeneration { let lm_head = if tie_word_embeddings { None } else { - Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?) + Some(linear_no_bias( + cfg.d_model, + cfg.vocab_size, + vb.pp("lm_head"), + )?) }; Ok(Self { diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs new file mode 100644 index 00000000..0a2d65b9 --- /dev/null +++ b/candle-transformers/src/models/with_tracing.rs @@ -0,0 +1,78 @@ +use candle::{Module, Result, Tensor}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +pub struct Embedding { + inner: candle_nn::Embedding, + span: tracing::Span, +} + +impl Embedding { + pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result { + let inner = candle_nn::embedding(d1, d2, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "embedding"); + Ok(Self { inner, span }) + } + + pub fn embeddings(&self) -> &Tensor { + self.inner.embeddings() + } +} + +impl Module for Embedding { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result { + let inner = candle_nn::linear(d1, d2, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { inner, span }) +} + +pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result { + let inner = candle_nn::linear_no_bias(d1, d2, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { inner, span }) +} + +impl Module for Linear { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +// Wrap the conv2d op to provide some tracing. +#[derive(Debug)] +pub struct Conv2d { + inner: candle_nn::Conv2d, + span: tracing::Span, +} + +impl Conv2d { + pub fn forward(&self, x: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +pub fn conv2d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: candle_nn::Conv2dConfig, + vs: candle_nn::VarBuilder, +) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "conv2d"); + let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?; + Ok(Conv2d { inner, span }) +}