Tracing mode for T5. (#913)
* Tracing mode for T5.

* Tracing for the linear layer.
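The change follows one pattern throughout: each layer keeps a tracing::Span, built once at construction time, and forward enters it for the duration of the call, so every module shows up as an interval in a Chrome trace. The first hunk wires the tracing-chrome subscriber into the example's main; the remaining hunks instrument the T5 model itself. A minimal sketch of the pattern (the MyLayer name and "my-layer" label are illustrative, not from this commit):

use candle::{Module, Result, Tensor};

struct MyLayer {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl MyLayer {
    fn new(inner: candle_nn::Linear) -> Self {
        // Build the span once; the per-call cost is then just enter()/drop.
        let span = tracing::span!(tracing::Level::TRACE, "my-layer");
        Self { inner, span }
    }
}

impl Module for MyLayer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // RAII guard: the span is recorded as entered until _enter drops.
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}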
@@ -150,7 +150,20 @@ impl T5ModelBuilder {
 }
 
 fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let args = Args::parse();
+
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
     let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
     let device = &builder.device;
     let tokenizer = tokenizer
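One subtlety in the hunk above: the flush guard returned by ChromeLayerBuilder::new().build() is bound to _guard rather than discarded. In Rust, let _ = expr drops the value immediately, while a named binding such as _guard lives to the end of main; tracing-chrome finalizes and flushes the trace file when its guard drops, so an immediately-dropped guard would yield an empty trace. A standalone illustration of that drop-timing difference (the FlushGuard here is a stand-in, not the real tracing_chrome type):

struct FlushGuard;

impl Drop for FlushGuard {
    fn drop(&mut self) {
        println!("trace flushed");
    }
}

fn main() {
    let _ = FlushGuard;      // dropped right away: prints immediately
    let _guard = FlushGuard; // named binding: lives until the end of main
    println!("traced work runs while the guard is alive");
} // _guard drops here, flushing the trace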
@@ -1,11 +1,32 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{embedding, Activation, Embedding, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;
 
+#[derive(Debug)]
+struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+        let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        Ok(Self { inner, span })
+    }
+}
+
+impl Module for Linear {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
 fn default_relative_attention_max_distance() -> usize {
     128
 }
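The wrapper keeps candle_nn::Linear as an inner field rather than instrumenting call sites, so every place that used to call linear_no_bias can switch to Linear::new and pick up tracing for free. Note also that the span is created once in new and only entered in forward: span construction does subscriber bookkeeping, while entering is a cheap RAII operation. A counter-sketch of what the hot path avoids (illustrative function, not from this commit):

use candle::{Module, Result, Tensor};

// Building the span inside forward would repeat the construction work on
// every call; the wrapper above pays it once at load time instead.
fn forward_with_fresh_span(inner: &candle_nn::Linear, xs: &Tensor) -> Result<Tensor> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let _enter = span.enter();
    inner.forward(xs)
}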
@@ -121,6 +142,7 @@ impl Config {
 struct T5LayerNorm {
     weight: Tensor,
     variance_epsilon: f64,
+    span: tracing::Span,
 }
 
 impl T5LayerNorm {
@@ -129,10 +151,14 @@ impl T5LayerNorm {
         Ok(Self {
             weight,
             variance_epsilon: eps,
+            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
         })
     }
+}
 
+impl Module for T5LayerNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
         // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
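The changed lines stop at the Python reference comment, but the rest of this forward is T5's RMS-style norm: mean of squares over the last dimension, no mean subtraction, computed in f32 and cast back. A standalone sketch of that computation (names illustrative; this mirrors the method body rather than quoting it):

use candle::{DType, Result, Tensor, D};

fn t5_rms_norm(xs: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    let dtype = xs.dtype();
    // Compute in f32 for numerical stability, as the method above does.
    let xs_f32 = xs.to_dtype(DType::F32)?;
    // variance = mean of squares over the last dim (no mean subtraction).
    let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
    let xs = xs_f32.broadcast_div(&(variance + eps)?.sqrt()?)?;
    xs.to_dtype(dtype)?.broadcast_mul(weight)
}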
@@ -149,20 +175,25 @@ struct T5DenseActDense {
     wi: Linear,
     wo: Linear,
     act: Activation,
+    span: tracing::Span,
 }
 
 impl T5DenseActDense {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
-        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+        let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
         Ok(Self {
             wi,
             wo,
             act: Activation::Relu,
+            span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
         })
     }
+}
 
+impl Module for T5DenseActDense {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.wi.forward(xs)?;
         let xs = self.act.forward(&xs)?;
         let xs = self.wo.forward(&xs)?;
@@ -176,22 +207,27 @@ struct T5DenseGatedActDense {
     wi_1: Linear,
     wo: Linear,
     act: Activation,
+    span: tracing::Span,
 }
 
 impl T5DenseGatedActDense {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
-        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
-        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+        let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+        let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
         Ok(Self {
             wi_0,
             wi_1,
             wo,
             act: Activation::NewGelu,
+            span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
         })
     }
+}
 
+impl Module for T5DenseGatedActDense {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
         let hidden_linear = self.wi_1.forward(xs)?;
         let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
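Between them these two hunks cover both T5 feed-forward variants: the plain one computes wo(relu(wi(x))) and the gated one computes wo(new_gelu(wi_0(x)) * wi_1(x)), where * is the elementwise product (broadcast_mul). Each variant is now a span of its own, with the nested "linear" spans of its projections visible inside it.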
@@ -205,6 +241,7 @@ struct T5LayerFF {
     dense_act: Option<T5DenseActDense>,
     gated_dense_act: Option<T5DenseGatedActDense>,
     layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5LayerFF {
@@ -226,10 +263,14 @@ impl T5LayerFF {
             dense_act,
             gated_dense_act,
             layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
         })
     }
+}
 
+impl Module for T5LayerFF {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let ys = self.layer_norm.forward(xs)?;
         let ys = match &self.dense_act {
             Some(dense_act) => dense_act.forward(&ys)?,
@@ -254,6 +295,7 @@ struct T5Attention {
     inner_dim: usize,
     use_cache: bool,
     kv_cache: Option<(Tensor, Tensor)>,
+    span: tracing::Span,
 }
 
 impl T5Attention {
@@ -264,10 +306,10 @@ impl T5Attention {
         cfg: &Config,
     ) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
-        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
-        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
-        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
+        let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?;
+        let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?;
+        let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
+        let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
         let relative_attention_bias = if has_relative_attention_bias {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
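With q, k, v, and o all routed through the traced wrapper, the four projections appear as separate "linear" intervals nested under the enclosing "attention" span, making it easy to see how much of attention time goes to the projection matmuls versus the score/softmax work in between.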
@@ -291,6 +333,7 @@ impl T5Attention {
             inner_dim,
             use_cache: cfg.use_cache && decoder,
             kv_cache: None,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
         })
     }
 
@@ -303,6 +346,7 @@ impl T5Attention {
     ) -> Result<(Tensor, Option<Tensor>)> {
         // Performs Self-attention (if key_value_states is None) or attention
         // over source sentence (provided by key_value_states).
+        let _enter = self.span.enter();
         let kv_input = match key_value_states {
             None => xs,
             Some(key_value_states) => key_value_states,
@@ -419,6 +463,7 @@ impl T5Attention {
 struct T5LayerSelfAttention {
     self_attention: T5Attention,
     layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5LayerSelfAttention {
@@ -429,6 +474,7 @@ impl T5LayerSelfAttention {
         Ok(Self {
             self_attention,
             layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
         })
     }
 
@@ -438,6 +484,7 @@ impl T5LayerSelfAttention {
         position_bias: Option<&Tensor>,
         mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
         let normed_xs = self.layer_norm.forward(xs)?;
         let (ys, position_bias) =
             self.self_attention
@@ -451,6 +498,7 @@ impl T5LayerSelfAttention {
 struct T5LayerCrossAttention {
     cross_attention: T5Attention,
     layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5LayerCrossAttention {
@@ -461,6 +509,7 @@ impl T5LayerCrossAttention {
         Ok(Self {
             cross_attention,
             layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
         })
     }
 
@@ -470,6 +519,7 @@ impl T5LayerCrossAttention {
         position_bias: Option<&Tensor>,
         key_value_states: &Tensor,
     ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
         let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
         let (ys, position_bias) = self.cross_attention.forward(
             &normed_hidden_states,
@@ -487,6 +537,7 @@ struct T5Block {
     self_attn: T5LayerSelfAttention,
     cross_attn: Option<T5LayerCrossAttention>,
     ff: T5LayerFF,
+    span: tracing::Span,
 }
 
 impl T5Block {
@@ -510,6 +561,7 @@ impl T5Block {
             self_attn,
             cross_attn,
             ff,
+            span: tracing::span!(tracing::Level::TRACE, "block"),
         })
     }
 
@@ -519,6 +571,7 @@ impl T5Block {
         position_bias: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
+        let _enter = self.span.enter();
         // TODO: Cache masks
         let mask = match self.cross_attn.is_some() {
             true => {
@@ -550,6 +603,7 @@ struct T5Stack {
     block: Vec<T5Block>,
     shared: Arc<Embedding>,
     final_layer_norm: T5LayerNorm,
+    span: tracing::Span,
 }
 
 impl T5Stack {
@@ -566,6 +620,7 @@ impl T5Stack {
             block,
             shared: shared.clone(),
             final_layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "stack"),
         })
     }
 
@@ -574,6 +629,7 @@ impl T5Stack {
         input_ids: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
@@ -592,6 +648,7 @@ impl T5Stack {
 pub struct T5EncoderModel {
     encoder: T5Stack,
     device: Device,
+    span: tracing::Span,
 }
 
 impl T5EncoderModel {
@@ -602,10 +659,12 @@ impl T5EncoderModel {
         Ok(Self {
             encoder,
             device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "encoder"),
         })
     }
 
     pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         self.encoder.forward(input_ids, None)
     }
 
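Every wrapper from Linear up to T5EncoderModel applies the same three-step recipe: a span field on the struct, a tracing::span! in the constructor, and self.span.enter() at the top of forward. The resulting trace therefore nests the way the model does: encoder → stack → block → self-attn / cross-attn / layer-ff → attention / layer-norm → linear.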
@@ -623,6 +682,7 @@ pub struct T5ForConditionalGeneration {
     lm_head: Option<Linear>,
     shared: Arc<Embedding>,
     device: Device,
+    span_decode: tracing::Span,
 }
 
 impl T5ForConditionalGeneration {
@@ -648,11 +708,7 @@ impl T5ForConditionalGeneration {
         let lm_head = if tie_word_embeddings {
             None
         } else {
-            Some(linear_no_bias(
-                cfg.d_model,
-                cfg.vocab_size,
-                vb.pp("lm_head"),
-            )?)
+            Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
         };
 
         Ok(Self {
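When tie_word_embeddings is set, lm_head stays None and the output projection reuses the shared embedding table, so that path gains no separate traced linear. A hedged sketch of what the tied branch amounts to (hypothetical helper, not lines from this diff; T5 also scales the decoder output by d_model^-0.5 before this projection when weights are tied):

use candle::{Result, Tensor};
use candle_nn::Embedding;

// With tied weights, logits come from the transposed embedding table
// instead of a dedicated lm_head projection.
fn tied_lm_logits(decoder_output: &Tensor, shared: &Embedding) -> Result<Tensor> {
    // (batch, seq, d_model) x (d_model, vocab) -> (batch, seq, vocab)
    decoder_output.broadcast_matmul(&shared.embeddings().t()?)
}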
@@ -663,6 +719,7 @@ impl T5ForConditionalGeneration {
             lm_head,
             shared,
             device: vb.device().clone(),
+            span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
         })
     }
 
@@ -675,6 +732,7 @@ impl T5ForConditionalGeneration {
         decoder_input_ids: &Tensor,
         encoder_output: &Tensor,
     ) -> Result<Tensor> {
+        let _enter = self.span_decode.enter();
         let decoder_output = self
             .decoder
             .forward(decoder_input_ids, Some(encoder_output))?;
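To capture a trace, run the example with the tracing flag, e.g. something like cargo run --example t5 --release -- --tracing ... (assuming clap derives --tracing from the args.tracing field used in main). When the process exits and the guard drops, tracing-chrome writes a JSON trace file to the working directory, which can be loaded in chrome://tracing or Perfetto to inspect the per-layer spans added above.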