Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 03:28:50 +00:00
Add some tracing to metavoice. (#1826)
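The pattern, repeated for every module touched below, is the usual candle tracing idiom: the struct gains a span: tracing::Span field, the constructor builds it once with tracing::span!(tracing::Level::TRACE, "..."), and forward enters the span for the duration of the call so each layer shows up as a named scope in a trace. A minimal, self-contained sketch of that idiom follows; the Dummy layer and its fields are illustrative only, not code from this commit.

    use tracing::Level;

    // Illustrative layer struct: it owns a span that is created once and
    // re-entered on every forward call.
    struct Dummy {
        scale: f64,
        span: tracing::Span,
    }

    impl Dummy {
        fn new(scale: f64) -> Self {
            Self {
                scale,
                span: tracing::span!(Level::TRACE, "dummy"),
            }
        }

        fn forward(&self, x: f64) -> f64 {
            // The guard returned by enter() closes the span when it is
            // dropped, i.e. when this function returns.
            let _enter = self.span.enter();
            x * self.scale
        }
    }

    fn main() {
        let layer = Dummy::new(2.0);
        println!("{}", layer.forward(3.0));
    }

Entering a span is cheap when no subscriber is installed, which is why the guards can sit unconditionally in the hot path.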
@@ -181,6 +181,7 @@ pub mod tokenizers {
         pub end_of_text: usize,
         pub offset: usize,
         pub ranks: HashMap<Vec<u8>, Rank>,
+        span: tracing::Span,
     }
 
     impl BPE {
@@ -231,6 +232,7 @@ pub mod tokenizers {
                 end_of_text,
                 offset,
                 ranks,
+                span: tracing::span!(tracing::Level::TRACE, "bpe"),
             })
         }
 
@@ -310,6 +312,7 @@ pub mod tokenizers {
         }
 
         pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
+            let _enter = self.span.enter();
             let mut bpe_tokens: Vec<u32> = Vec::new();
             for word in self.re.find_iter(text) {
                 let word = word.map_err(E::wrap)?;
@@ -426,6 +429,7 @@ pub mod gpt {
         c_attn: Linear,
         c_proj: Linear,
         n_head: usize,
+        span: tracing::Span,
     }
 
     impl SelfAttention {
@@ -444,12 +448,14 @@ pub mod gpt {
                 c_attn,
                 c_proj,
                 n_head: cfg.n_head,
+                span: tracing::span!(tracing::Level::TRACE, "self-attn"),
             })
         }
     }
 
     impl Module for SelfAttention {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (b, t, c) = xs.dims3()?;
             let c_x = xs
                 .apply(&self.c_attn)?
@@ -474,11 +480,13 @@ pub mod gpt {
        Gelu {
            c_fc: Linear,
            c_proj: Linear,
+           span: tracing::Span,
        },
        Swiglu {
            w1: Linear,
            w3: Linear,
            c_proj: Linear,
+           span: tracing::Span,
        },
    }
 
@@ -489,7 +497,11 @@ pub mod gpt {
                NonLinearityType::Gelu => {
                    let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
-                   Self::Gelu { c_fc, c_proj }
+                   Self::Gelu {
+                       c_fc,
+                       c_proj,
+                       span: tracing::span!(tracing::Level::TRACE, "mlp-gelu"),
+                   }
                }
                NonLinearityType::Swiglu => {
                    let hidden_dim = (2 * hidden_dim) / 3;
@@ -502,7 +514,12 @@ pub mod gpt {
                    let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
                    let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
-                   Self::Swiglu { w1, w3, c_proj }
+                   Self::Swiglu {
+                       w1,
+                       w3,
+                       c_proj,
+                       span: tracing::span!(tracing::Level::TRACE, "mlp-swiglu"),
+                   }
                }
            };
            Ok(slf)
@@ -512,8 +529,17 @@ pub mod gpt {
    impl Module for MLP {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            match self {
-               Self::Gelu { c_fc, c_proj } => xs.apply(c_fc)?.gelu()?.apply(c_proj),
-               Self::Swiglu { w1, w3, c_proj } => {
+               Self::Gelu { c_fc, c_proj, span } => {
+                   let _enter = span.enter();
+                   xs.apply(c_fc)?.gelu()?.apply(c_proj)
+               }
+               Self::Swiglu {
+                   w1,
+                   w3,
+                   c_proj,
+                   span,
+               } => {
+                   let _enter = span.enter();
                    let w1 = xs.apply(w1)?;
                    let w3 = xs.apply(w3)?;
                    (w1.silu()? * w3)?.apply(c_proj)
@@ -528,6 +554,7 @@ pub mod gpt {
        ln_2: Norm,
        attn: SelfAttention,
        mlp: MLP,
+       span: tracing::Span,
    }
 
    impl Block {
@@ -541,12 +568,14 @@ pub mod gpt {
                ln_2,
                attn,
                mlp,
+               span: tracing::span!(tracing::Level::TRACE, "gpt-block"),
            })
        }
    }
 
    impl Module for Block {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+           let _enter = self.span.enter();
            let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
            let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
            Ok(xs)
@@ -563,6 +592,7 @@ pub mod gpt {
        lm_heads: Vec<Linear>,
        cfg: Config,
        dtype: DType,
+       span: tracing::Span,
    }
 
    impl Model {
@@ -598,6 +628,7 @@ pub mod gpt {
                lm_heads,
                cfg,
                dtype: vb.dtype(),
+               span: tracing::span!(tracing::Level::TRACE, "gpt"),
            })
        }
 
@@ -606,6 +637,7 @@ pub mod gpt {
        }
 
        pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
+           let _enter = self.span.enter();
            let device = idx.device();
            let (b, _num_hierarchies, t) = idx.dims3()?;
            let pos = Tensor::arange(0u32, t as u32, device)?;
@@ -689,6 +721,7 @@ pub mod transformer {
        w1: Linear,
        w2: Linear,
        w3: Linear,
+       span: tracing::Span,
    }
 
    impl FeedForward {
@@ -697,12 +730,18 @@ pub mod transformer {
            let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
            let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
            let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
-           Ok(Self { w1, w2, w3 })
+           Ok(Self {
+               w1,
+               w2,
+               w3,
+               span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
+           })
        }
    }
 
    impl Module for FeedForward {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+           let _enter = self.span.enter();
            let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
            swiglu.apply(&self.w2)
        }
@@ -718,6 +757,7 @@ pub mod transformer {
        head_dim: usize,
        n_head: usize,
        kv_cache: Option<(Tensor, Tensor)>,
+       span: tracing::Span,
    }
 
    impl Attention {
@@ -736,10 +776,12 @@ pub mod transformer {
                head_dim,
                n_head: cfg.n_head,
                kv_cache: None,
+               span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
            })
        }
 
        fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+           let _enter = self.span.enter();
            let (b_sz, seqlen, _) = xs.dims3()?;
 
            let qkv = xs.apply(&self.wqkv)?;
@@ -793,6 +835,7 @@ pub mod transformer {
        feed_forward: FeedForward,
        ffn_norm: RmsNorm,
        attention_norm: RmsNorm,
+       span: tracing::Span,
    }
 
    impl Block {
@@ -806,10 +849,12 @@ pub mod transformer {
                feed_forward,
                ffn_norm,
                attention_norm,
+               span: tracing::span!(tracing::Level::TRACE, "block"),
            })
        }
 
        fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+           let _enter = self.span.enter();
            let hs = xs.apply(&self.attention_norm)?;
            let hs = (xs + self.attention.forward(&hs, pos, mask))?;
            &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
@@ -829,6 +874,7 @@ pub mod transformer {
        norm: RmsNorm,
        output: Linear,
        spk_cond_mask: Tensor,
+       span: tracing::Span,
    }
 
    impl Model {
@@ -865,6 +911,7 @@ pub mod transformer {
                norm,
                output,
                spk_cond_mask,
+               span: tracing::span!(tracing::Level::TRACE, "transformer"),
            })
        }
 
@@ -875,6 +922,7 @@ pub mod transformer {
        }
 
        pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+           let _enter = self.span.enter();
            let (_b_sz, seqlen) = xs.dims2()?;
            let mask: Vec<_> = (0..seqlen)
                .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
@@ -905,14 +953,19 @@ pub mod adapters {
    // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py
    pub struct TiltedEncodec {
        end_of_audio_token: u32,
+       span: tracing::Span,
    }
 
    impl TiltedEncodec {
        pub fn new(end_of_audio_token: u32) -> Self {
-           Self { end_of_audio_token }
+           Self {
+               end_of_audio_token,
+               span: tracing::span!(tracing::Level::TRACE, "tilted-encodec"),
+           }
        }
 
        pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
+           let _enter = self.span.enter();
            let mut text_ids = vec![];
            let mut extracted_audio_ids = vec![];
            let mut min_audio_ids_len = usize::MAX;
@@ -941,14 +994,19 @@ pub mod adapters {
    // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4
    pub struct FlattenedInterleavedEncodec2Codebook {
        end_of_audio_token: u32,
+       span: tracing::Span,
    }
 
    impl FlattenedInterleavedEncodec2Codebook {
        pub fn new(end_of_audio_token: u32) -> Self {
-           Self { end_of_audio_token }
+           Self {
+               end_of_audio_token,
+               span: tracing::span!(tracing::Level::TRACE, "encodec2codebook"),
+           }
        }
 
        pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
+           let _enter = self.span.enter();
            let mut text_ids = vec![];
            let mut audio_ids1 = vec![];
            let mut audio_ids2 = vec![];
@@ -14,6 +14,7 @@ pub mod transformer {
        w1: Linear,
        w2: Linear,
        w3: Linear,
+       span: tracing::Span,
    }
 
    impl FeedForward {
@@ -22,12 +23,18 @@ pub mod transformer {
            let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
            let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
            let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
-           Ok(Self { w1, w2, w3 })
+           Ok(Self {
+               w1,
+               w2,
+               w3,
+               span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
+           })
        }
    }
 
    impl Module for FeedForward {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+           let _enter = self.span.enter();
            let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
            swiglu.apply(&self.w2)
        }
@@ -43,6 +50,7 @@ pub mod transformer {
        head_dim: usize,
        n_head: usize,
        kv_cache: Option<(Tensor, Tensor)>,
+       span: tracing::Span,
    }
 
    impl Attention {
@@ -61,10 +69,12 @@ pub mod transformer {
                head_dim,
                n_head: cfg.n_head,
                kv_cache: None,
+               span: tracing::span!(tracing::Level::TRACE, "attention"),
            })
        }
 
        fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+           let _enter = self.span.enter();
            let (b_sz, seqlen, _) = xs.dims3()?;
 
            let qkv = xs.apply(&self.wqkv)?;
@@ -118,6 +128,7 @@ pub mod transformer {
        feed_forward: FeedForward,
        ffn_norm: RmsNorm,
        attention_norm: RmsNorm,
+       span: tracing::Span,
    }
 
    impl Block {
@@ -131,10 +142,12 @@ pub mod transformer {
                feed_forward,
                ffn_norm,
                attention_norm,
+               span: tracing::span!(tracing::Level::TRACE, "block"),
            })
        }
 
        fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+           let _enter = self.span.enter();
            let hs = xs.apply(&self.attention_norm)?;
            let hs = (xs + self.attention.forward(&hs, pos, mask))?;
            &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
@@ -154,6 +167,7 @@ pub mod transformer {
        norm: RmsNorm,
        output: Linear,
        spk_cond_mask: Tensor,
+       span: tracing::Span,
    }
 
    impl Model {
@@ -189,6 +203,7 @@ pub mod transformer {
                norm,
                output,
                spk_cond_mask,
+               span: tracing::span!(tracing::Level::TRACE, "qtransformer"),
            })
        }
 
@@ -199,6 +214,7 @@ pub mod transformer {
        }
 
        pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+           let _enter = self.span.enter();
            let (_b_sz, seqlen) = xs.dims2()?;
            let mask: Vec<_> = (0..seqlen)
                .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
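These spans only produce output once a subscriber is installed by the caller. The sketch below shows the kind of setup candle's examples use for their --tracing flag, recording a Chrome trace via the tracing-chrome crate; the surrounding main and the comments are illustrative, not part of this commit.

    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    fn main() {
        // Record every span to a Chrome trace file (trace-<timestamp>.json
        // by default), viewable in chrome://tracing or Perfetto. The guard
        // must stay alive until the end of the run so the file is flushed.
        let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();

        // ... build the metavoice model and run generation here; the "bpe",
        // "gpt", "transformer", "qtransformer", ... spans added by this
        // commit then show up as named scopes in the trace.
    }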