From 9b24d89d2d637d72ff1bc52b9ea8bb7ccdd11c88 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 20 Sep 2023 15:03:35 +0100
Subject: [PATCH] Tracing mode for T5. (#913)

* Tracing mode for T5.

* Tracing for the linear layer.
---
 candle-examples/examples/t5/main.rs  | 13 ++++
 candle-transformers/src/models/t5.rs | 90 +++++++++++++++++++++++-----
 2 files changed, 87 insertions(+), 16 deletions(-)

diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index f5972754..348e9a55 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -150,7 +150,20 @@ impl T5ModelBuilder {
 }
 
 fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let args = Args::parse();
+
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
     let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
     let device = &builder.device;
     let tokenizer = tokenizer
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index fd2720d3..efb2819b 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -1,11 +1,32 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{embedding, Activation, Embedding, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;
 
+#[derive(Debug)]
+struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+        let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        Ok(Self { inner, span })
+    }
+}
+
+impl Module for Linear {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
 fn default_relative_attention_max_distance() -> usize {
     128
 }
@@ -121,6 +142,7 @@ impl Config {
 struct T5LayerNorm {
     weight: Tensor,
     variance_epsilon: f64,
+    span: tracing::Span,
 }
 
 impl T5LayerNorm {
@@ -129,10 +151,14 @@ impl T5LayerNorm {
         Ok(Self {
             weight,
             variance_epsilon: eps,
+            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
         })
     }
+}
 
+impl Module for T5LayerNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
         // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
@@ -149,20 +175,25 @@ struct T5DenseActDense {
     wi: Linear,
     wo: Linear,
     act: Activation,
+    span: tracing::Span,
 }
 
 impl T5DenseActDense {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
-        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+        let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
         Ok(Self {
             wi,
             wo,
             act: Activation::Relu,
+            span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
         })
     }
+}
 
+impl Module for T5DenseActDense {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.wi.forward(xs)?;
         let xs = self.act.forward(&xs)?;
         let xs = self.wo.forward(&xs)?;
@@ -176,22 +207,27 @@ struct T5DenseGatedActDense {
     wi_1: Linear,
     wo: Linear,
     act: Activation,
+    span: tracing::Span,
 }
 
 impl T5DenseGatedActDense {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
-        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
-        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+        let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+        let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
         Ok(Self {
             wi_0,
             wi_1,
             wo,
             act: Activation::NewGelu,
+            span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
         })
     }
+}
 
+impl Module for T5DenseGatedActDense {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
         let hidden_linear = self.wi_1.forward(xs)?;
         let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
@@ -205,6 +241,7 @@ struct T5LayerFF {
     dense_act: Option<T5DenseActDense>,
     gated_dense_act: Option<T5DenseGatedActDense>,
     layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5LayerFF {
@@ -226,10 +263,14 @@ impl T5LayerFF {
             dense_act,
             gated_dense_act,
             layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
         })
     }
+}
 
+impl Module for T5LayerFF {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let ys = self.layer_norm.forward(xs)?;
         let ys = match &self.dense_act {
             Some(dense_act) => dense_act.forward(&ys)?,
@@ -254,6 +295,7 @@ struct T5Attention {
     inner_dim: usize,
     use_cache: bool,
     kv_cache: Option<(Tensor, Tensor)>,
+    span: tracing::Span,
 }
 
 impl T5Attention {
@@ -264,10 +306,10 @@ impl T5Attention {
         cfg: &Config,
     ) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
-        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
-        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
-        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
+        let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?;
+        let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?;
+        let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
+        let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
         let relative_attention_bias = if has_relative_attention_bias {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
@@ -291,6 +333,7 @@ impl T5Attention {
             inner_dim,
             use_cache: cfg.use_cache && decoder,
             kv_cache: None,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
         })
     }
 
@@ -303,6 +346,7 @@ impl T5Attention {
     ) -> Result<(Tensor, Option<Tensor>)> {
         // Performs Self-attention (if key_value_states is None) or attention
         // over source sentence (provided by key_value_states).
+        let _enter = self.span.enter();
         let kv_input = match key_value_states {
             None => xs,
             Some(key_value_states) => key_value_states,
         };
@@ -419,6 +463,7 @@ impl T5Attention {
 struct T5LayerSelfAttention {
     self_attention: T5Attention,
     layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5LayerSelfAttention {
@@ -429,6 +474,7 @@ impl T5LayerSelfAttention {
         Ok(Self {
             self_attention,
             layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
         })
     }
 
@@ -438,6 +484,7 @@ impl T5LayerSelfAttention {
         position_bias: Option<&Tensor>,
         mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
         let normed_xs = self.layer_norm.forward(xs)?;
         let (ys, position_bias) = self.self_attention
             .forward(&normed_xs, position_bias, None, mask)?;
@@ -451,6 +498,7 @@ impl T5LayerSelfAttention {
 struct T5LayerCrossAttention {
     cross_attention: T5Attention,
     layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5LayerCrossAttention {
@@ -461,6 +509,7 @@ impl T5LayerCrossAttention {
         Ok(Self {
             cross_attention,
             layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
         })
     }
 
@@ -470,6 +519,7 @@ impl T5LayerCrossAttention {
         position_bias: Option<&Tensor>,
         key_value_states: &Tensor,
     ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
         let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
         let (ys, position_bias) = self.cross_attention.forward(
             &normed_hidden_states,
             position_bias,
@@ -487,6 +537,7 @@ struct T5Block {
     self_attn: T5LayerSelfAttention,
     cross_attn: Option<T5LayerCrossAttention>,
     ff: T5LayerFF,
+    span: tracing::Span,
 }
 
 impl T5Block {
@@ -510,6 +561,7 @@ impl T5Block {
             self_attn,
             cross_attn,
             ff,
+            span: tracing::span!(tracing::Level::TRACE, "block"),
         })
     }
 
@@ -519,6 +571,7 @@ impl T5Block {
         position_bias: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
         // TODO: Cache masks
         let mask = match self.cross_attn.is_some() {
             true => {
@@ -550,6 +603,7 @@ struct T5Stack {
     block: Vec<T5Block>,
     shared: Arc<Embedding>,
     final_layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5Stack {
@@ -566,6 +620,7 @@ impl T5Stack {
             block,
             shared: shared.clone(),
             final_layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "stack"),
         })
     }
 
@@ -574,6 +629,7 @@ impl T5Stack {
         input_ids: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
@@ -592,6 +648,7 @@ impl T5Stack {
 pub struct T5EncoderModel {
     encoder: T5Stack,
     device: Device,
+    span: tracing::Span,
 }
 
 impl T5EncoderModel {
@@ -602,10 +659,12 @@ impl T5EncoderModel {
         Ok(Self {
             encoder,
             device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "encoder"),
         })
     }
 
     pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         self.encoder.forward(input_ids, None)
     }
 
@@ -623,6 +682,7 @@ pub struct T5ForConditionalGeneration {
     lm_head: Option<Linear>,
     shared: Arc<Embedding>,
     device: Device,
+    span_decode: tracing::Span,
 }
 
 impl T5ForConditionalGeneration {
@@ -648,11 +708,7 @@ impl T5ForConditionalGeneration {
         let lm_head = if tie_word_embeddings {
             None
         } else {
-            Some(linear_no_bias(
-                cfg.d_model,
-                cfg.vocab_size,
-                vb.pp("lm_head"),
-            )?)
+            Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
         };
 
         Ok(Self {
@@ -663,6 +719,7 @@ impl T5ForConditionalGeneration {
             lm_head,
             shared,
             device: vb.device().clone(),
+            span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
        })
     }
 
@@ -675,6 +732,7 @@ impl T5ForConditionalGeneration {
         decoder_input_ids: &Tensor,
         encoder_output: &Tensor,
     ) -> Result<Tensor> {
+        let _enter = self.span_decode.enter();
         let decoder_output = self
             .decoder
             .forward(decoder_input_ids, Some(encoder_output))?;
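
A quick way to exercise the new spans, as a sketch: only the tracing setup above is part of this patch, so the prompt and any model-selection flags below stand in for whatever the t5 example already accepts on its command line.

    cargo run --example t5 --release -- --tracing --prompt "translate to German: A beautiful candle."

`ChromeLayerBuilder::new().build()` returns the chrome layer together with a flush guard; keeping that guard alive in `_guard` for the whole of `main` means tracing-chrome flushes and writes its JSON trace (by default a file named along the lines of `trace-<timestamp>.json` in the working directory, unless a path is configured on the builder) when the guard is dropped at exit. Opening that file in `chrome://tracing` or Perfetto shows one slice per entered span ("linear", "attention", "block", "encoder", "decode", ...). Each span is created once in the constructor and only entered in `forward`, which keeps the per-call overhead low while still putting every forward pass on the timeline.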