Tracing mode for T5. (#913)

* Tracing mode for T5.

* Tracing for the linear layer.
Laurent Mazare
2023-09-20 15:03:35 +01:00
committed by GitHub
parent fb1c2ac535
commit 9b24d89d2d
2 changed files with 87 additions and 16 deletions
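The same instrumentation pattern is applied to every layer touched below: each module keeps a tracing::Span that is created once at load time, and forward enters it so the time spent in the call is attributed to that span. A minimal sketch of the idea, distilled from the hunks below (the Toy name is illustrative, not part of the change):

use candle::{Module, Result, Tensor};

struct Toy {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Module for Toy {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // The guard returned by enter() is dropped at the end of the call,
        // so the whole forward pass is measured under the span.
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}

Creating the span once at load time rather than on every call keeps the per-forward cost to entering and exiting an existing span, which is cheap and essentially free when no tracing subscriber is installed.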

@@ -150,7 +150,20 @@ impl T5ModelBuilder {
     }
 }
 fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let args = Args::parse();
+
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
     let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
     let device = &builder.device;
     let tokenizer = tokenizer
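One detail worth noting in the hunk above: the FlushGuard returned by ChromeLayerBuilder::new().build() is bound to _guard at the top of main so that it stays alive for the whole run, and dropping it is what flushes the trace. The standalone sketch below shows the same setup in isolation; the default output location is an assumption based on tracing-chrome's documented behaviour, not something this diff sets:

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive until the end of main; dropping it flushes the trace,
    // which tracing-chrome writes by default as a trace-<timestamp>.json in the
    // current directory, viewable in chrome://tracing or Perfetto.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    let span = tracing::span!(tracing::Level::TRACE, "demo");
    let _enter = span.enter();
    // ... work done while the span guard is alive shows up as a "demo" slice ...
}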

@@ -1,11 +1,32 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
-use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{embedding, Activation, Embedding, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;
+
+#[derive(Debug)]
+struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+        let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        Ok(Self { inner, span })
+    }
+}
+
+impl Module for Linear {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
 
 fn default_relative_attention_max_distance() -> usize {
     128
 }
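Because the wrapper implements the Module trait, it is a drop-in replacement for candle_nn::Linear at every call site, which is why the remaining hunks only swap linear_no_bias(...) for Linear::new(...). A quick way to check that the spans actually fire, without going through the Chrome layer, is to log span-close events to stderr; the sketch below does that with illustrative shapes and names (my assumptions, not part of the diff):

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};
use tracing_subscriber::fmt::format::FmtSpan;

fn check_linear_span() -> Result<()> {
    // Log every span when it closes, including its busy/idle time.
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::TRACE)
        .with_span_events(FmtSpan::CLOSE)
        .init();

    // A VarMap-backed builder creates the weight on demand, so no checkpoint is needed.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let layer = Linear::new(16, 4, vb.pp("demo"))?; // the wrapper defined in the hunk above
    let xs = Tensor::zeros((2, 16), DType::F32, &Device::Cpu)?;
    let _ys = layer.forward(&xs)?; // emits a close event for the "linear" span
    Ok(())
}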
@@ -121,6 +142,7 @@ impl Config {
 struct T5LayerNorm {
     weight: Tensor,
     variance_epsilon: f64,
+    span: tracing::Span,
 }
 
 impl T5LayerNorm {
@@ -129,10 +151,14 @@ impl T5LayerNorm {
         Ok(Self {
             weight,
             variance_epsilon: eps,
+            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
         })
     }
+}
 
+impl Module for T5LayerNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
         // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
@@ -149,20 +175,25 @@ struct T5DenseActDense {
     wi: Linear,
     wo: Linear,
     act: Activation,
+    span: tracing::Span,
 }
 
 impl T5DenseActDense {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
-        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+        let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
         Ok(Self {
             wi,
             wo,
             act: Activation::Relu,
+            span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
         })
     }
+}
 
+impl Module for T5DenseActDense {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.wi.forward(xs)?;
         let xs = self.act.forward(&xs)?;
         let xs = self.wo.forward(&xs)?;
@@ -176,22 +207,27 @@ struct T5DenseGatedActDense {
     wi_1: Linear,
     wo: Linear,
     act: Activation,
+    span: tracing::Span,
 }
 
 impl T5DenseGatedActDense {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
-        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
-        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+        let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+        let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
         Ok(Self {
             wi_0,
             wi_1,
             wo,
             act: Activation::NewGelu,
+            span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
         })
     }
+}
 
+impl Module for T5DenseGatedActDense {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
         let hidden_linear = self.wi_1.forward(xs)?;
         let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
@@ -205,6 +241,7 @@ struct T5LayerFF {
     dense_act: Option<T5DenseActDense>,
     gated_dense_act: Option<T5DenseGatedActDense>,
     layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5LayerFF {
@@ -226,10 +263,14 @@ impl T5LayerFF {
             dense_act,
             gated_dense_act,
             layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
         })
     }
+}
 
+impl Module for T5LayerFF {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let ys = self.layer_norm.forward(xs)?;
         let ys = match &self.dense_act {
             Some(dense_act) => dense_act.forward(&ys)?,
@@ -254,6 +295,7 @@ struct T5Attention {
     inner_dim: usize,
     use_cache: bool,
     kv_cache: Option<(Tensor, Tensor)>,
+    span: tracing::Span,
 }
 
 impl T5Attention {
@@ -264,10 +306,10 @@ impl T5Attention {
         cfg: &Config,
     ) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
-        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
-        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
-        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
+        let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?;
+        let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?;
+        let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
+        let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
         let relative_attention_bias = if has_relative_attention_bias {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
@@ -291,6 +333,7 @@ impl T5Attention {
             inner_dim,
             use_cache: cfg.use_cache && decoder,
             kv_cache: None,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
         })
     }
@@ -303,6 +346,7 @@ impl T5Attention {
     ) -> Result<(Tensor, Option<Tensor>)> {
         // Performs Self-attention (if key_value_states is None) or attention
         // over source sentence (provided by key_value_states).
+        let _enter = self.span.enter();
         let kv_input = match key_value_states {
             None => xs,
             Some(key_value_states) => key_value_states,
@@ -419,6 +463,7 @@ impl T5Attention {
 struct T5LayerSelfAttention {
     self_attention: T5Attention,
     layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5LayerSelfAttention {
@@ -429,6 +474,7 @@ impl T5LayerSelfAttention {
         Ok(Self {
             self_attention,
             layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
         })
     }
@@ -438,6 +484,7 @@ impl T5LayerSelfAttention {
         position_bias: Option<&Tensor>,
         mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
         let normed_xs = self.layer_norm.forward(xs)?;
         let (ys, position_bias) =
             self.self_attention
@@ -451,6 +498,7 @@ impl T5LayerSelfAttention {
 struct T5LayerCrossAttention {
     cross_attention: T5Attention,
     layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5LayerCrossAttention {
@@ -461,6 +509,7 @@ impl T5LayerCrossAttention {
         Ok(Self {
             cross_attention,
             layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
         })
     }
@@ -470,6 +519,7 @@ impl T5LayerCrossAttention {
         position_bias: Option<&Tensor>,
         key_value_states: &Tensor,
     ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
         let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
         let (ys, position_bias) = self.cross_attention.forward(
             &normed_hidden_states,
@@ -487,6 +537,7 @@ struct T5Block {
     self_attn: T5LayerSelfAttention,
     cross_attn: Option<T5LayerCrossAttention>,
     ff: T5LayerFF,
+    span: tracing::Span,
 }
 
 impl T5Block {
@@ -510,6 +561,7 @@ impl T5Block {
             self_attn,
             cross_attn,
             ff,
+            span: tracing::span!(tracing::Level::TRACE, "block"),
         })
     }
@@ -519,6 +571,7 @@ impl T5Block {
         position_bias: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
         // TODO: Cache masks
         let mask = match self.cross_attn.is_some() {
             true => {
@@ -550,6 +603,7 @@ struct T5Stack {
     block: Vec<T5Block>,
     shared: Arc<Embedding>,
     final_layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5Stack {
@@ -566,6 +620,7 @@ impl T5Stack {
             block,
             shared: shared.clone(),
             final_layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "stack"),
         })
     }
@@ -574,6 +629,7 @@ impl T5Stack {
         input_ids: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
@@ -592,6 +648,7 @@ impl T5Stack {
 pub struct T5EncoderModel {
     encoder: T5Stack,
     device: Device,
+    span: tracing::Span,
 }
 
 impl T5EncoderModel {
@@ -602,10 +659,12 @@ impl T5EncoderModel {
         Ok(Self {
             encoder,
             device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "encoder"),
         })
     }
 
     pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         self.encoder.forward(input_ids, None)
     }
@@ -623,6 +682,7 @@ pub struct T5ForConditionalGeneration {
     lm_head: Option<Linear>,
     shared: Arc<Embedding>,
     device: Device,
+    span_decode: tracing::Span,
 }
 
 impl T5ForConditionalGeneration {
@@ -648,11 +708,7 @@ impl T5ForConditionalGeneration {
         let lm_head = if tie_word_embeddings {
             None
         } else {
-            Some(linear_no_bias(
-                cfg.d_model,
-                cfg.vocab_size,
-                vb.pp("lm_head"),
-            )?)
+            Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
         };
 
         Ok(Self {
@@ -663,6 +719,7 @@ impl T5ForConditionalGeneration {
             lm_head,
             shared,
             device: vb.device().clone(),
+            span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
         })
     }
@@ -675,6 +732,7 @@ impl T5ForConditionalGeneration {
         decoder_input_ids: &Tensor,
         encoder_output: &Tensor,
     ) -> Result<Tensor> {
+        let _enter = self.span_decode.enter();
        let decoder_output = self
             .decoder
             .forward(decoder_input_ids, Some(encoder_output))?;
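A closing note on what the trace looks like: the Chrome layer records a slice for the time each span is entered, and slices on the same thread nest by time. Since every forward above holds its guard for the whole call, one decode step should appear roughly as decode > stack > block > self-attn / cross-attn / layer-ff > attention / layer-norm / linear, which is what makes per-layer hot spots easy to pick out. A toy illustration of that nesting (span names taken from this diff; the surrounding function is illustrative):

fn nesting_demo() {
    let decode = tracing::span!(tracing::Level::TRACE, "decode");
    let _decode = decode.enter();

    let block = tracing::span!(tracing::Level::TRACE, "block");
    let _block = block.enter(); // this slice sits inside "decode" in the timeline

    let linear = tracing::span!(tracing::Level::TRACE, "linear");
    let _linear = linear.enter(); // and this one inside "block"
}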