Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)
Add more t5 tracing. (#915)
This commit is contained in:
@@ -321,6 +321,8 @@ struct T5Attention {
 use_cache: bool,
 kv_cache: Option<(Tensor, Tensor)>,
 span: tracing::Span,
+span_cache: tracing::Span,
+span_mm: tracing::Span,
 span_sm: tracing::Span,
 }
 
@@ -360,6 +362,8 @@ impl T5Attention {
 use_cache: cfg.use_cache && decoder,
 kv_cache: None,
 span: tracing::span!(tracing::Level::TRACE, "attention"),
+span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
+span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
 span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
 })
 }
@@ -397,6 +401,7 @@ impl T5Attention {
 .contiguous()?;
 
 if self.use_cache {
+let _enter = self.span_cache.enter();
 if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
 k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
 v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
@@ -404,7 +409,10 @@ impl T5Attention {
 self.kv_cache = Some((k.clone(), v.clone()));
 };
 // TODO: Use flash_attn.
-let scores = q.matmul(&k.t()?)?;
+let scores = {
+let _enter = self.span_mm.enter();
+q.matmul(&k.t()?)?
+};
 let scores = match mask {
 None => scores,
 Some(mask) => masked_fill(
@@ -713,6 +721,7 @@ pub struct T5ForConditionalGeneration {
 shared: Arc<Embedding>,
 device: Device,
 span_decode: tracing::Span,
+span_decode_head: tracing::Span,
 }
 
 impl T5ForConditionalGeneration {
@@ -750,6 +759,7 @@ impl T5ForConditionalGeneration {
 shared,
 device: vb.device().clone(),
 span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
+span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
 })
 }
 
@@ -778,9 +788,12 @@ impl T5ForConditionalGeneration {
 .narrow(1, decoder_output.dim(1)? - 1, 1)?
 .squeeze(1)?)
 * scaling_factor)?;
-let output = match self.lm_head {
+let output = {
+let _enter = self.span_decode_head.enter();
+match self.lm_head {
 None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
 Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+}
 };
 
 // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
Reference in New Issue
Block a user