Add some tracing to metavoice. (#1826)

Laurent Mazare
2024-03-09 12:24:11 +01:00
committed by GitHub
parent 56c9d3ee7b
commit 0c5eecbc0f
2 changed files with 82 additions and 8 deletions
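The spans added below only show up if a `tracing` subscriber is installed in the binary that runs the model. As a minimal sketch (assuming the `tracing-chrome` and `tracing-subscriber` crates; the exact wiring of the metavoice example may differ), one way to collect them into a Chrome-compatible trace file is:

// Hypothetical setup: install a subscriber that records spans into a
// trace-*.json file viewable in chrome://tracing or Perfetto.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // The guard flushes the trace file when dropped, so keep it alive
    // for the whole run.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... build the metavoice model and run generation here; the "bpe",
    // "gpt", "transformer", ... spans added in this commit then appear
    // as nested slices in the resulting trace.
}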

View File

@@ -181,6 +181,7 @@ pub mod tokenizers {
         pub end_of_text: usize,
         pub offset: usize,
         pub ranks: HashMap<Vec<u8>, Rank>,
+        span: tracing::Span,
     }
 
     impl BPE {
@@ -231,6 +232,7 @@ pub mod tokenizers {
                 end_of_text,
                 offset,
                 ranks,
+                span: tracing::span!(tracing::Level::TRACE, "bpe"),
             })
         }
@@ -310,6 +312,7 @@ pub mod tokenizers {
         }
 
         pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
+            let _enter = self.span.enter();
             let mut bpe_tokens: Vec<u32> = Vec::new();
             for word in self.re.find_iter(text) {
                 let word = word.map_err(E::wrap)?;
@@ -426,6 +429,7 @@ pub mod gpt {
         c_attn: Linear,
         c_proj: Linear,
         n_head: usize,
+        span: tracing::Span,
     }
 
     impl SelfAttention {
@@ -444,12 +448,14 @@ pub mod gpt {
                 c_attn,
                 c_proj,
                 n_head: cfg.n_head,
+                span: tracing::span!(tracing::Level::TRACE, "self-attn"),
             })
         }
     }
 
     impl Module for SelfAttention {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (b, t, c) = xs.dims3()?;
             let c_x = xs
                 .apply(&self.c_attn)?
@@ -474,11 +480,13 @@ pub mod gpt {
         Gelu {
             c_fc: Linear,
             c_proj: Linear,
+            span: tracing::Span,
         },
         Swiglu {
             w1: Linear,
             w3: Linear,
             c_proj: Linear,
+            span: tracing::Span,
         },
     }
@@ -489,7 +497,11 @@ pub mod gpt {
                 NonLinearityType::Gelu => {
                     let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
                     let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
-                    Self::Gelu { c_fc, c_proj }
+                    Self::Gelu {
+                        c_fc,
+                        c_proj,
+                        span: tracing::span!(tracing::Level::TRACE, "mlp-gelu"),
+                    }
                 }
                 NonLinearityType::Swiglu => {
                     let hidden_dim = (2 * hidden_dim) / 3;
@@ -502,7 +514,12 @@ pub mod gpt {
                     let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
                     let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
                     let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
-                    Self::Swiglu { w1, w3, c_proj }
+                    Self::Swiglu {
+                        w1,
+                        w3,
+                        c_proj,
+                        span: tracing::span!(tracing::Level::TRACE, "mlp-swiglu"),
+                    }
                 }
             };
             Ok(slf)
@@ -512,8 +529,17 @@ pub mod gpt {
     impl Module for MLP {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
             match self {
-                Self::Gelu { c_fc, c_proj } => xs.apply(c_fc)?.gelu()?.apply(c_proj),
-                Self::Swiglu { w1, w3, c_proj } => {
+                Self::Gelu { c_fc, c_proj, span } => {
+                    let _enter = span.enter();
+                    xs.apply(c_fc)?.gelu()?.apply(c_proj)
+                }
+                Self::Swiglu {
+                    w1,
+                    w3,
+                    c_proj,
+                    span,
+                } => {
+                    let _enter = span.enter();
                     let w1 = xs.apply(w1)?;
                     let w3 = xs.apply(w3)?;
                     (w1.silu()? * w3)?.apply(c_proj)
@@ -528,6 +554,7 @@ pub mod gpt {
         ln_2: Norm,
         attn: SelfAttention,
         mlp: MLP,
+        span: tracing::Span,
     }
 
     impl Block {
@@ -541,12 +568,14 @@ pub mod gpt {
                 ln_2,
                 attn,
                 mlp,
+                span: tracing::span!(tracing::Level::TRACE, "gpt-block"),
             })
         }
     }
 
     impl Module for Block {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
             let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
             Ok(xs)
@@ -563,6 +592,7 @@ pub mod gpt {
         lm_heads: Vec<Linear>,
         cfg: Config,
         dtype: DType,
+        span: tracing::Span,
     }
 
     impl Model {
@@ -598,6 +628,7 @@ pub mod gpt {
                 lm_heads,
                 cfg,
                 dtype: vb.dtype(),
+                span: tracing::span!(tracing::Level::TRACE, "gpt"),
             })
         }
@@ -606,6 +637,7 @@ pub mod gpt {
         }
 
         pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
+            let _enter = self.span.enter();
             let device = idx.device();
             let (b, _num_hierarchies, t) = idx.dims3()?;
             let pos = Tensor::arange(0u32, t as u32, device)?;
@@ -689,6 +721,7 @@ pub mod transformer {
         w1: Linear,
         w2: Linear,
         w3: Linear,
+        span: tracing::Span,
     }
 
     impl FeedForward {
@@ -697,12 +730,18 @@ pub mod transformer {
             let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
             let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
             let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
-            Ok(Self { w1, w2, w3 })
+            Ok(Self {
+                w1,
+                w2,
+                w3,
+                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
+            })
         }
     }
 
     impl Module for FeedForward {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
             swiglu.apply(&self.w2)
         }
@@ -718,6 +757,7 @@ pub mod transformer {
         head_dim: usize,
         n_head: usize,
         kv_cache: Option<(Tensor, Tensor)>,
+        span: tracing::Span,
     }
 
     impl Attention {
@@ -736,10 +776,12 @@ pub mod transformer {
                 head_dim,
                 n_head: cfg.n_head,
                 kv_cache: None,
+                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
             })
         }
 
         fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (b_sz, seqlen, _) = xs.dims3()?;
 
             let qkv = xs.apply(&self.wqkv)?;
@@ -793,6 +835,7 @@ pub mod transformer {
         feed_forward: FeedForward,
         ffn_norm: RmsNorm,
         attention_norm: RmsNorm,
+        span: tracing::Span,
     }
 
     impl Block {
@@ -806,10 +849,12 @@ pub mod transformer {
                 feed_forward,
                 ffn_norm,
                 attention_norm,
+                span: tracing::span!(tracing::Level::TRACE, "block"),
             })
         }
 
         fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let hs = xs.apply(&self.attention_norm)?;
             let hs = (xs + self.attention.forward(&hs, pos, mask))?;
             &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
@@ -829,6 +874,7 @@ pub mod transformer {
         norm: RmsNorm,
         output: Linear,
         spk_cond_mask: Tensor,
+        span: tracing::Span,
     }
 
     impl Model {
@@ -865,6 +911,7 @@ pub mod transformer {
                 norm,
                 output,
                 spk_cond_mask,
+                span: tracing::span!(tracing::Level::TRACE, "transformer"),
             })
         }
@@ -875,6 +922,7 @@ pub mod transformer {
         }
 
         pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (_b_sz, seqlen) = xs.dims2()?;
             let mask: Vec<_> = (0..seqlen)
                 .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
@@ -905,14 +953,19 @@ pub mod adapters {
     // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py
     pub struct TiltedEncodec {
         end_of_audio_token: u32,
+        span: tracing::Span,
     }
 
     impl TiltedEncodec {
         pub fn new(end_of_audio_token: u32) -> Self {
-            Self { end_of_audio_token }
+            Self {
+                end_of_audio_token,
+                span: tracing::span!(tracing::Level::TRACE, "tilted-encodec"),
+            }
         }
 
         pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
+            let _enter = self.span.enter();
             let mut text_ids = vec![];
             let mut extracted_audio_ids = vec![];
             let mut min_audio_ids_len = usize::MAX;
@@ -941,14 +994,19 @@ pub mod adapters {
     // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4
     pub struct FlattenedInterleavedEncodec2Codebook {
         end_of_audio_token: u32,
+        span: tracing::Span,
     }
 
     impl FlattenedInterleavedEncodec2Codebook {
         pub fn new(end_of_audio_token: u32) -> Self {
-            Self { end_of_audio_token }
+            Self {
+                end_of_audio_token,
+                span: tracing::span!(tracing::Level::TRACE, "encodec2codebook"),
+            }
         }
 
         pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
+            let _enter = self.span.enter();
             let mut text_ids = vec![];
             let mut audio_ids1 = vec![];
             let mut audio_ids2 = vec![];

View File

@@ -14,6 +14,7 @@ pub mod transformer {
         w1: Linear,
         w2: Linear,
         w3: Linear,
+        span: tracing::Span,
     }
 
     impl FeedForward {
@@ -22,12 +23,18 @@ pub mod transformer {
             let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
             let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
             let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
-            Ok(Self { w1, w2, w3 })
+            Ok(Self {
+                w1,
+                w2,
+                w3,
+                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
+            })
         }
     }
 
     impl Module for FeedForward {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
             swiglu.apply(&self.w2)
         }
@@ -43,6 +50,7 @@ pub mod transformer {
         head_dim: usize,
         n_head: usize,
         kv_cache: Option<(Tensor, Tensor)>,
+        span: tracing::Span,
     }
 
     impl Attention {
@@ -61,10 +69,12 @@ pub mod transformer {
                 head_dim,
                 n_head: cfg.n_head,
                 kv_cache: None,
+                span: tracing::span!(tracing::Level::TRACE, "attention"),
             })
         }
 
         fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (b_sz, seqlen, _) = xs.dims3()?;
 
             let qkv = xs.apply(&self.wqkv)?;
@@ -118,6 +128,7 @@ pub mod transformer {
         feed_forward: FeedForward,
         ffn_norm: RmsNorm,
         attention_norm: RmsNorm,
+        span: tracing::Span,
     }
 
     impl Block {
@@ -131,10 +142,12 @@ pub mod transformer {
                 feed_forward,
                 ffn_norm,
                 attention_norm,
+                span: tracing::span!(tracing::Level::TRACE, "block"),
             })
         }
 
         fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let hs = xs.apply(&self.attention_norm)?;
             let hs = (xs + self.attention.forward(&hs, pos, mask))?;
             &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
@@ -154,6 +167,7 @@ pub mod transformer {
         norm: RmsNorm,
         output: Linear,
         spk_cond_mask: Tensor,
+        span: tracing::Span,
     }
 
     impl Model {
@@ -189,6 +203,7 @@ pub mod transformer {
                 norm,
                 output,
                 spk_cond_mask,
+                span: tracing::span!(tracing::Level::TRACE, "qtransformer"),
             })
         }
@@ -199,6 +214,7 @@ pub mod transformer {
         }
 
         pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (_b_sz, seqlen) = xs.dims2()?;
             let mask: Vec<_> = (0..seqlen)
                 .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
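For context on the pattern used throughout this diff: each struct stores a `tracing::Span` built once in its constructor, and every `forward` enters that span for the duration of the call. A minimal sketch of the idiom, with a hypothetical `Layer` type rather than the actual candle structs:

struct Layer {
    span: tracing::Span,
}

impl Layer {
    fn new() -> Self {
        // The span is created once; entering and exiting it on each call
        // is cheap compared to rebuilding it.
        Self {
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        }
    }

    fn forward(&self) {
        // Binding the guard to `_enter` (not `_`) keeps the span entered
        // until the end of the function; `let _ = self.span.enter();`
        // would drop the guard immediately and record an empty span.
        let _enter = self.span.enter();
        // ... actual computation happens inside the span ...
    }
}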