mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add more t5 tracing. (#914)
* Add more t5 tracing. * Rever the sm change.
This commit is contained in:
@ -2,10 +2,35 @@
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Activation, Embedding, VarBuilder};
|
||||
use candle_nn::{Activation, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Embedding {
|
||||
inner: candle_nn::Embedding,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let inner = candle_nn::embedding(d1, d2, vb)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "embedding");
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
fn embeddings(&self) -> &Tensor {
|
||||
self.inner.embeddings()
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Embedding {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Linear {
|
||||
inner: candle_nn::Linear,
|
||||
@ -296,6 +321,7 @@ struct T5Attention {
|
||||
use_cache: bool,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
span: tracing::Span,
|
||||
span_sm: tracing::Span,
|
||||
}
|
||||
|
||||
impl T5Attention {
|
||||
@ -311,7 +337,7 @@ impl T5Attention {
|
||||
let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
|
||||
let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
|
||||
let relative_attention_bias = if has_relative_attention_bias {
|
||||
let emb = embedding(
|
||||
let emb = Embedding::new(
|
||||
cfg.relative_attention_num_buckets,
|
||||
cfg.num_heads,
|
||||
vb.pp("relative_attention_bias"),
|
||||
@ -334,6 +360,7 @@ impl T5Attention {
|
||||
use_cache: cfg.use_cache && decoder,
|
||||
kv_cache: None,
|
||||
span: tracing::span!(tracing::Level::TRACE, "attention"),
|
||||
span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
|
||||
})
|
||||
}
|
||||
|
||||
@ -449,7 +476,10 @@ impl T5Attention {
|
||||
},
|
||||
};
|
||||
|
||||
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
|
||||
let attn_weights = {
|
||||
let _enter = self.span_sm.enter();
|
||||
candle_nn::ops::softmax(&scores, D::Minus1)?
|
||||
};
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
let attn_output = attn_output
|
||||
.transpose(1, 2)?
|
||||
@ -653,7 +683,7 @@ pub struct T5EncoderModel {
|
||||
|
||||
impl T5EncoderModel {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared = Arc::new(shared);
|
||||
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
|
||||
Ok(Self {
|
||||
@ -689,7 +719,7 @@ impl T5ForConditionalGeneration {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
assert!(cfg.is_encoder_decoder);
|
||||
let d_model = cfg.d_model;
|
||||
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared = Arc::new(shared);
|
||||
|
||||
let mut encoder_cfg = cfg.clone();
|
||||
|
Reference in New Issue
Block a user