Add more t5 tracing. (#914)

* Add more t5 tracing.

* Rever the sm change.
This commit is contained in:
Laurent Mazare
2023-09-20 16:37:51 +01:00
committed by GitHub
parent 9b24d89d2d
commit 3a0d3e05df

View File

@ -2,10 +2,35 @@
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{embedding, Activation, Embedding, VarBuilder};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;
#[derive(Debug)]
struct Embedding {
inner: candle_nn::Embedding,
span: tracing::Span,
}
impl Embedding {
fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
let inner = candle_nn::embedding(d1, d2, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "embedding");
Ok(Self { inner, span })
}
fn embeddings(&self) -> &Tensor {
self.inner.embeddings()
}
}
impl Module for Embedding {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
#[derive(Debug)]
struct Linear {
inner: candle_nn::Linear,
@ -296,6 +321,7 @@ struct T5Attention {
use_cache: bool,
kv_cache: Option<(Tensor, Tensor)>,
span: tracing::Span,
span_sm: tracing::Span,
}
impl T5Attention {
@ -311,7 +337,7 @@ impl T5Attention {
let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
let relative_attention_bias = if has_relative_attention_bias {
let emb = embedding(
let emb = Embedding::new(
cfg.relative_attention_num_buckets,
cfg.num_heads,
vb.pp("relative_attention_bias"),
@ -334,6 +360,7 @@ impl T5Attention {
use_cache: cfg.use_cache && decoder,
kv_cache: None,
span: tracing::span!(tracing::Level::TRACE, "attention"),
span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
})
}
@ -449,7 +476,10 @@ impl T5Attention {
},
};
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_weights = {
let _enter = self.span_sm.enter();
candle_nn::ops::softmax(&scores, D::Minus1)?
};
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
@ -653,7 +683,7 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
@ -689,7 +719,7 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone();