diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index ffa2764b..2b71fcda 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -321,6 +321,8 @@ struct T5Attention {
     use_cache: bool,
     kv_cache: Option<(Tensor, Tensor)>,
     span: tracing::Span,
+    span_cache: tracing::Span,
+    span_mm: tracing::Span,
     span_sm: tracing::Span,
 }
 
@@ -360,6 +362,8 @@ impl T5Attention {
             use_cache: cfg.use_cache && decoder,
             kv_cache: None,
             span: tracing::span!(tracing::Level::TRACE, "attention"),
+            span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
+            span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
             span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
         })
     }
@@ -397,6 +401,7 @@ impl T5Attention {
             .contiguous()?;
 
         if self.use_cache {
+            let _enter = self.span_cache.enter();
             if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
                 k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
                 v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
@@ -404,7 +409,10 @@
             self.kv_cache = Some((k.clone(), v.clone()));
         };
         // TODO: Use flash_attn.
-        let scores = q.matmul(&k.t()?)?;
+        let scores = {
+            let _enter = self.span_mm.enter();
+            q.matmul(&k.t()?)?
+        };
         let scores = match mask {
             None => scores,
             Some(mask) => masked_fill(
@@ -713,6 +721,7 @@ pub struct T5ForConditionalGeneration {
     shared: Arc<Embedding>,
     device: Device,
     span_decode: tracing::Span,
+    span_decode_head: tracing::Span,
 }
 
 impl T5ForConditionalGeneration {
@@ -750,6 +759,7 @@ impl T5ForConditionalGeneration {
             shared,
             device: vb.device().clone(),
             span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
+            span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
         })
     }
 
@@ -778,9 +788,12 @@
             .narrow(1, decoder_output.dim(1)? - 1, 1)?
             .squeeze(1)?)
             * scaling_factor)?;
-        let output = match self.lm_head {
-            None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
-            Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+        let output = {
+            let _enter = self.span_decode_head.enter();
+            match self.lm_head {
+                None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+                Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+            }
         };
         // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
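
Note: the spans added above only produce data once a tracing subscriber is installed in the calling binary. The following driver is a minimal sketch, not part of this patch, assuming the tracing-chrome and tracing-subscriber crates (the setup used by the candle examples); it shows how the "attention-cache", "attention-mm" and "decode-head" spans can be captured for inspection.

// Illustrative sketch only: install a Chrome trace layer so the TRACE-level
// spans added in t5.rs are written to a trace-*.json file.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run so the trace file is flushed on exit.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... build the T5 model and run generation here (hypothetical); the
    // "attention-cache", "attention-mm" and "decode-head" spans then appear
    // in the generated trace, viewable in chrome://tracing or ui.perfetto.dev.
}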