diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs
index 2eeb0713..43de594f 100644
--- a/candle-transformers/src/models/metavoice.rs
+++ b/candle-transformers/src/models/metavoice.rs
@@ -181,6 +181,7 @@ pub mod tokenizers {
         pub end_of_text: usize,
         pub offset: usize,
         pub ranks: HashMap<Vec<u8>, Rank>,
+        span: tracing::Span,
     }
 
     impl BPE {
@@ -231,6 +232,7 @@ pub mod tokenizers {
                 end_of_text,
                 offset,
                 ranks,
+                span: tracing::span!(tracing::Level::TRACE, "bpe"),
             })
         }
 
@@ -310,6 +312,7 @@ pub mod tokenizers {
         }
 
         pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
+            let _enter = self.span.enter();
             let mut bpe_tokens: Vec<u32> = Vec::new();
             for word in self.re.find_iter(text) {
                 let word = word.map_err(E::wrap)?;
@@ -426,6 +429,7 @@ pub mod gpt {
         c_attn: Linear,
         c_proj: Linear,
         n_head: usize,
+        span: tracing::Span,
     }
 
     impl SelfAttention {
@@ -444,12 +448,14 @@ pub mod gpt {
                 c_attn,
                 c_proj,
                 n_head: cfg.n_head,
+                span: tracing::span!(tracing::Level::TRACE, "self-attn"),
             })
         }
     }
 
     impl Module for SelfAttention {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (b, t, c) = xs.dims3()?;
             let c_x = xs
                 .apply(&self.c_attn)?
@@ -474,11 +480,13 @@ pub mod gpt {
         Gelu {
             c_fc: Linear,
             c_proj: Linear,
+            span: tracing::Span,
         },
         Swiglu {
             w1: Linear,
             w3: Linear,
             c_proj: Linear,
+            span: tracing::Span,
         },
     }
 
@@ -489,7 +497,11 @@ pub mod gpt {
                 NonLinearityType::Gelu => {
                     let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
                     let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
-                    Self::Gelu { c_fc, c_proj }
+                    Self::Gelu {
+                        c_fc,
+                        c_proj,
+                        span: tracing::span!(tracing::Level::TRACE, "mlp-gelu"),
+                    }
                 }
                 NonLinearityType::Swiglu => {
                     let hidden_dim = (2 * hidden_dim) / 3;
@@ -502,7 +514,12 @@ pub mod gpt {
                     let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
                     let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
                     let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
-                    Self::Swiglu { w1, w3, c_proj }
+                    Self::Swiglu {
+                        w1,
+                        w3,
+                        c_proj,
+                        span: tracing::span!(tracing::Level::TRACE, "mlp-swiglu"),
+                    }
                 }
             };
             Ok(slf)
@@ -512,8 +529,17 @@ pub mod gpt {
     impl Module for MLP {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
             match self {
-                Self::Gelu { c_fc, c_proj } => xs.apply(c_fc)?.gelu()?.apply(c_proj),
-                Self::Swiglu { w1, w3, c_proj } => {
+                Self::Gelu { c_fc, c_proj, span } => {
+                    let _enter = span.enter();
+                    xs.apply(c_fc)?.gelu()?.apply(c_proj)
+                }
+                Self::Swiglu {
+                    w1,
+                    w3,
+                    c_proj,
+                    span,
+                } => {
+                    let _enter = span.enter();
                     let w1 = xs.apply(w1)?;
                     let w3 = xs.apply(w3)?;
                     (w1.silu()? * w3)?.apply(c_proj)
@@ -528,6 +554,7 @@ pub mod gpt {
         ln_2: Norm,
         attn: SelfAttention,
         mlp: MLP,
+        span: tracing::Span,
     }
 
     impl Block {
@@ -541,12 +568,14 @@ pub mod gpt {
                 ln_2,
                 attn,
                 mlp,
+                span: tracing::span!(tracing::Level::TRACE, "gpt-block"),
             })
         }
     }
 
     impl Module for Block {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
             let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
             Ok(xs)
@@ -563,6 +592,7 @@ pub mod gpt {
         lm_heads: Vec<Linear>,
         cfg: Config,
         dtype: DType,
+        span: tracing::Span,
     }
 
     impl Model {
@@ -598,6 +628,7 @@ pub mod gpt {
                 lm_heads,
                 cfg,
                 dtype: vb.dtype(),
+                span: tracing::span!(tracing::Level::TRACE, "gpt"),
             })
         }
 
@@ -606,6 +637,7 @@ pub mod gpt {
         }
 
         pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
+            let _enter = self.span.enter();
             let device = idx.device();
             let (b, _num_hierarchies, t) = idx.dims3()?;
             let pos = Tensor::arange(0u32, t as u32, device)?;
@@ -689,6 +721,7 @@ pub mod transformer {
         w1: Linear,
         w2: Linear,
         w3: Linear,
+        span: tracing::Span,
     }
 
     impl FeedForward {
@@ -697,12 +730,18 @@ pub mod transformer {
             let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
             let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
             let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
-            Ok(Self { w1, w2, w3 })
+            Ok(Self {
+                w1,
+                w2,
+                w3,
+                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
+            })
         }
     }
 
     impl Module for FeedForward {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
             swiglu.apply(&self.w2)
         }
@@ -718,6 +757,7 @@ pub mod transformer {
         head_dim: usize,
         n_head: usize,
         kv_cache: Option<(Tensor, Tensor)>,
+        span: tracing::Span,
     }
 
     impl Attention {
@@ -736,10 +776,12 @@ pub mod transformer {
                 head_dim,
                 n_head: cfg.n_head,
                 kv_cache: None,
+                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
             })
         }
 
         fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (b_sz, seqlen, _) = xs.dims3()?;
 
             let qkv = xs.apply(&self.wqkv)?;
@@ -793,6 +835,7 @@ pub mod transformer {
         feed_forward: FeedForward,
         ffn_norm: RmsNorm,
         attention_norm: RmsNorm,
+        span: tracing::Span,
     }
 
     impl Block {
@@ -806,10 +849,12 @@ pub mod transformer {
                 feed_forward,
                 ffn_norm,
                 attention_norm,
+                span: tracing::span!(tracing::Level::TRACE, "block"),
             })
         }
 
         fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let hs = xs.apply(&self.attention_norm)?;
             let hs = (xs + self.attention.forward(&hs, pos, mask))?;
             &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
@@ -829,6 +874,7 @@ pub mod transformer {
         norm: RmsNorm,
         output: Linear,
         spk_cond_mask: Tensor,
+        span: tracing::Span,
     }
 
     impl Model {
@@ -865,6 +911,7 @@ pub mod transformer {
                 norm,
                 output,
                 spk_cond_mask,
+                span: tracing::span!(tracing::Level::TRACE, "transformer"),
             })
         }
 
@@ -875,6 +922,7 @@ pub mod transformer {
         }
 
         pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (_b_sz, seqlen) = xs.dims2()?;
             let mask: Vec<_> = (0..seqlen)
                 .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
@@ -905,14 +953,19 @@ pub mod adapters {
     // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py
     pub struct TiltedEncodec {
         end_of_audio_token: u32,
+        span: tracing::Span,
     }
 
     impl TiltedEncodec {
         pub fn new(end_of_audio_token: u32) -> Self {
-            Self { end_of_audio_token }
+            Self {
+                end_of_audio_token,
+                span: tracing::span!(tracing::Level::TRACE, "tilted-encodec"),
+            }
         }
 
         pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
+            let _enter = self.span.enter();
             let mut text_ids = vec![];
             let mut extracted_audio_ids = vec![];
             let mut min_audio_ids_len = usize::MAX;
@@ -941,14 +994,19 @@ pub mod adapters {
     // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4
     pub struct FlattenedInterleavedEncodec2Codebook {
         end_of_audio_token: u32,
+        span: tracing::Span,
     }
 
     impl FlattenedInterleavedEncodec2Codebook {
         pub fn new(end_of_audio_token: u32) -> Self {
-            Self { end_of_audio_token }
+            Self {
+                end_of_audio_token,
+                span: tracing::span!(tracing::Level::TRACE, "encodec2codebook"),
+            }
        }
 
         pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
+            let _enter = self.span.enter();
             let mut text_ids = vec![];
             let mut audio_ids1 = vec![];
             let mut audio_ids2 = vec![];
diff --git a/candle-transformers/src/models/quantized_metavoice.rs b/candle-transformers/src/models/quantized_metavoice.rs
index 16545150..84c0388c 100644
--- a/candle-transformers/src/models/quantized_metavoice.rs
+++ b/candle-transformers/src/models/quantized_metavoice.rs
@@ -14,6 +14,7 @@ pub mod transformer {
         w1: Linear,
         w2: Linear,
         w3: Linear,
+        span: tracing::Span,
     }
 
     impl FeedForward {
@@ -22,12 +23,18 @@ pub mod transformer {
             let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
             let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
             let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
-            Ok(Self { w1, w2, w3 })
+            Ok(Self {
+                w1,
+                w2,
+                w3,
+                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
+            })
         }
     }
 
     impl Module for FeedForward {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
             swiglu.apply(&self.w2)
         }
@@ -43,6 +50,7 @@ pub mod transformer {
         head_dim: usize,
         n_head: usize,
         kv_cache: Option<(Tensor, Tensor)>,
+        span: tracing::Span,
     }
 
     impl Attention {
@@ -61,10 +69,12 @@ pub mod transformer {
                 head_dim,
                 n_head: cfg.n_head,
                 kv_cache: None,
+                span: tracing::span!(tracing::Level::TRACE, "attention"),
             })
         }
 
         fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (b_sz, seqlen, _) = xs.dims3()?;
 
             let qkv = xs.apply(&self.wqkv)?;
@@ -118,6 +128,7 @@ pub mod transformer {
         feed_forward: FeedForward,
         ffn_norm: RmsNorm,
         attention_norm: RmsNorm,
+        span: tracing::Span,
     }
 
     impl Block {
@@ -131,10 +142,12 @@ pub mod transformer {
                 feed_forward,
                 ffn_norm,
                 attention_norm,
+                span: tracing::span!(tracing::Level::TRACE, "block"),
             })
         }
 
         fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let hs = xs.apply(&self.attention_norm)?;
             let hs = (xs + self.attention.forward(&hs, pos, mask))?;
             &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
@@ -154,6 +167,7 @@ pub mod transformer {
         norm: RmsNorm,
         output: Linear,
         spk_cond_mask: Tensor,
+        span: tracing::Span,
     }
 
     impl Model {
@@ -189,6 +203,7 @@ pub mod transformer {
                 norm,
                 output,
                 spk_cond_mask,
+                span: tracing::span!(tracing::Level::TRACE, "qtransformer"),
             })
         }
 
@@ -199,6 +214,7 @@ pub mod transformer {
         }
 
         pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (_b_sz, seqlen) = xs.dims2()?;
             let mask: Vec<_> = (0..seqlen)
                 .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
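
Note: the spans added in this diff are created at TRACE level and cost essentially nothing unless a tracing subscriber is installed. Below is a minimal sketch of how a driver binary could surface them as a Chrome trace; it assumes the `tracing-subscriber` and `tracing-chrome` crates are added as dependencies, and the model-building/generation code is elided.

    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    fn main() {
        // Install a subscriber that records every entered span (including the
        // TRACE-level spans added in this diff) to a trace-<timestamp>.json
        // file, viewable in chrome://tracing or Perfetto.
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();

        // ... load the metavoice models and run generation here ...

        // Dropping the guard flushes the trace file to disk.
        drop(guard);
    }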