From 3a0d3e05df74c214f21fe8c2188424e475542569 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 20 Sep 2023 16:37:51 +0100
Subject: [PATCH] Add more t5 tracing. (#914)

* Add more t5 tracing.

* Revert the softmax change.
---
 candle-transformers/src/models/t5.rs | 40 ++++++++++++++++++++++++----
 1 file changed, 35 insertions(+), 5 deletions(-)

diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index efb2819b..ffa2764b 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -2,10 +2,35 @@
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{embedding, Activation, Embedding, VarBuilder};
+use candle_nn::{Activation, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;
 
+#[derive(Debug)]
+struct Embedding {
+    inner: candle_nn::Embedding,
+    span: tracing::Span,
+}
+
+impl Embedding {
+    fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+        let inner = candle_nn::embedding(d1, d2, vb)?;
+        let span = tracing::span!(tracing::Level::TRACE, "embedding");
+        Ok(Self { inner, span })
+    }
+
+    fn embeddings(&self) -> &Tensor {
+        self.inner.embeddings()
+    }
+}
+
+impl Module for Embedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
 #[derive(Debug)]
 struct Linear {
     inner: candle_nn::Linear,
@@ -296,6 +321,7 @@ struct T5Attention {
     use_cache: bool,
     kv_cache: Option<(Tensor, Tensor)>,
     span: tracing::Span,
+    span_sm: tracing::Span,
 }
 
 impl T5Attention {
@@ -311,7 +337,7 @@ impl T5Attention {
         let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
         let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
         let relative_attention_bias = if has_relative_attention_bias {
-            let emb = embedding(
+            let emb = Embedding::new(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
                 vb.pp("relative_attention_bias"),
@@ -334,6 +360,7 @@ impl T5Attention {
             use_cache: cfg.use_cache && decoder,
             kv_cache: None,
             span: tracing::span!(tracing::Level::TRACE, "attention"),
+            span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
         })
     }
 
@@ -449,7 +476,10 @@ impl T5Attention {
             },
         };
 
-        let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
+        let attn_weights = {
+            let _enter = self.span_sm.enter();
+            candle_nn::ops::softmax(&scores, D::Minus1)?
+        };
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = attn_output
             .transpose(1, 2)?
@@ -653,7 +683,7 @@ pub struct T5EncoderModel {
 
 impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
         let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
         Ok(Self {
@@ -689,7 +719,7 @@ impl T5ForConditionalGeneration {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         assert!(cfg.is_encoder_decoder);
         let d_model = cfg.d_model;
-        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
 
         let mut encoder_cfg = cfg.clone();
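
Note: the "embedding", "attention", and "attention-sm" spans introduced above are only recorded once the calling binary installs a tracing subscriber. A minimal sketch of one way to surface them, using the tracing-chrome layer that candle's examples rely on; the main() scaffold here is an illustration, not part of this patch:

    // Sketch only: collect TRACE-level spans (including the new
    // "embedding" and "attention-sm" spans) into a Chrome trace file.
    // Assumes tracing, tracing-chrome and tracing-subscriber as dependencies.
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    fn main() {
        // The guard flushes the trace-<timestamp>.json file when dropped.
        let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();

        // ... load the T5 model and run a forward pass here; every
        // Embedding::forward and every attention softmax now emits a span.
    }

Opening the resulting file in chrome://tracing (or Perfetto) then shows how much time the embedding lookups and the attention softmax take relative to the surrounding matmuls, which is what the extra spans in this patch are for.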