Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
Add more t5 tracing. (#914)
* Add more t5 tracing.
* Revert the sm change.
@@ -2,10 +2,35 @@
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{embedding, Activation, Embedding, VarBuilder};
+use candle_nn::{Activation, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;
 
+#[derive(Debug)]
+struct Embedding {
+    inner: candle_nn::Embedding,
+    span: tracing::Span,
+}
+
+impl Embedding {
+    fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+        let inner = candle_nn::embedding(d1, d2, vb)?;
+        let span = tracing::span!(tracing::Level::TRACE, "embedding");
+        Ok(Self { inner, span })
+    }
+
+    fn embeddings(&self) -> &Tensor {
+        self.inner.embeddings()
+    }
+}
+
+impl Module for Embedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
 #[derive(Debug)]
 struct Linear {
     inner: candle_nn::Linear,
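The wrapper above delegates to candle_nn::Embedding and records a TRACE-level span around every forward pass. The same pattern generalizes to other layers; a minimal hypothetical sketch (the TracedLayerNorm name and the "layer-norm" span label are illustrative, not part of this commit):

#[derive(Debug)]
struct TracedLayerNorm {
    inner: candle_nn::LayerNorm,
    span: tracing::Span,
}

impl candle::Module for TracedLayerNorm {
    fn forward(&self, xs: &candle::Tensor) -> candle::Result<candle::Tensor> {
        // Entering the span records the time spent in the wrapped forward call
        // whenever a tracing subscriber is active.
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}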
@@ -296,6 +321,7 @@ struct T5Attention {
     use_cache: bool,
     kv_cache: Option<(Tensor, Tensor)>,
     span: tracing::Span,
+    span_sm: tracing::Span,
 }
 
 impl T5Attention {
@@ -311,7 +337,7 @@ impl T5Attention {
         let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
         let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
         let relative_attention_bias = if has_relative_attention_bias {
-            let emb = embedding(
+            let emb = Embedding::new(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
                 vb.pp("relative_attention_bias"),
@@ -334,6 +360,7 @@ impl T5Attention {
             use_cache: cfg.use_cache && decoder,
             kv_cache: None,
             span: tracing::span!(tracing::Level::TRACE, "attention"),
+            span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
         })
     }
 
@@ -449,7 +476,10 @@ impl T5Attention {
             },
         };
 
-        let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
+        let attn_weights = {
+            let _enter = self.span_sm.enter();
+            candle_nn::ops::softmax(&scores, D::Minus1)?
+        };
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = attn_output
             .transpose(1, 2)?
@@ -653,7 +683,7 @@ pub struct T5EncoderModel {
 
 impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
         let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
         Ok(Self {
@@ -689,7 +719,7 @@ impl T5ForConditionalGeneration {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         assert!(cfg.is_encoder_decoder);
         let d_model = cfg.d_model;
-        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
 
         let mut encoder_cfg = cfg.clone();
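The new "embedding" and "attention-sm" spans, like the existing "attention" span, only produce output once a tracing subscriber is installed. A minimal sketch of capturing them as a Chrome trace, assuming the tracing-chrome and tracing-subscriber crates (the setup candle's examples typically use behind a --tracing flag); the surrounding program is a placeholder:

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run; dropping it flushes the
    // trace-*.json output, which can be opened in chrome://tracing or Perfetto.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... load and run the T5 model here; each forward pass through the
    // wrapped Embedding and each attention softmax now shows up as a span.
}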