diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index 304e91ee..d117e4b3 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -14,6 +14,7 @@ struct MLP {
     up_proj: Linear,
     down_proj: Linear,
     act_fn: Activation,
+    span: tracing::Span,
 }
 
 impl MLP {
@@ -28,12 +29,14 @@ impl MLP {
             up_proj,
             down_proj,
             act_fn: cfg.hidden_act,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
         })
     }
 }
 
 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
         let rhs = xs.apply(&self.up_proj)?;
         (lhs * rhs)?.apply(&self.down_proj)
@@ -55,6 +58,7 @@ struct Attention {
     kv_cache: Option<(Tensor, Tensor)>,
     use_cache: bool,
     rotary_ndims: usize,
+    span: tracing::Span,
 }
 
 impl Attention {
@@ -81,6 +85,7 @@ impl Attention {
             kv_cache: None,
             use_cache: cfg.use_cache,
             rotary_ndims: cfg.rotary_ndims(),
+            span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
 
@@ -102,6 +107,7 @@
         attention_mask: Option<&Tensor>,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_sz, q_len, _) = xs.dims3()?;
 
         let query_states = self.q_proj.forward(xs)?;
@@ -168,6 +174,7 @@ struct DecoderLayer {
     mlp: MLP,
     input_layernorm: LayerNorm,
     post_attention_layernorm: LayerNorm,
+    span: tracing::Span,
 }
 
 impl DecoderLayer {
@@ -185,6 +192,7 @@ impl DecoderLayer {
             mlp,
             input_layernorm,
             post_attention_layernorm,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
 
@@ -194,6 +202,7 @@
         attention_mask: Option<&Tensor>,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let xs = self.input_layernorm.forward(xs)?;
         let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
@@ -211,6 +220,7 @@ pub struct Model {
     norm: LayerNorm,
     lm_head: Linear,
     device: Device,
+    span: tracing::Span,
 }
 
 impl Model {
@@ -233,6 +243,7 @@ impl Model {
             norm,
             lm_head,
             device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }
 
@@ -258,6 +269,7 @@
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_size, seq_len) = input_ids.dims2()?;
         let attention_mask = if seq_len <= 1 {
             None
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index affb28cf..ef06ea99 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -109,6 +109,7 @@ struct MLP {
     up_proj: Linear,
     down_proj: Linear,
     act_fn: Activation,
+    span: tracing::Span,
 }
 
 impl MLP {
@@ -123,12 +124,14 @@ impl MLP {
             up_proj,
             down_proj,
             act_fn: cfg.hidden_act,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
         })
     }
 }
 
 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
         let rhs = xs.apply(&self.up_proj)?;
         (lhs * rhs)?.apply(&self.down_proj)
@@ -167,6 +170,7 @@ struct Attention {
     use_cache: bool,
     rotary_ndims: usize,
     use_flash_attn: bool,
+    span: tracing::Span,
 }
 
 impl Attention {
@@ -194,6 +198,7 @@ impl Attention {
             use_cache: cfg.use_cache,
             rotary_ndims: cfg.rotary_ndims(),
             use_flash_attn: cfg.use_flash_attn,
+            span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
 
@@ -215,6 +220,7 @@
         attention_mask: Option<&Tensor>,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_sz, q_len, _) = xs.dims3()?;
 
         let query_states = self.q_proj.forward(xs)?;
@@ -288,6 +294,7 @@ struct DecoderLayer {
     mlp: MLP,
     input_layernorm: LayerNorm,
     post_attention_layernorm: LayerNorm,
+    span: tracing::Span,
 }
 
 impl DecoderLayer {
@@ -306,6 +313,7 @@ impl DecoderLayer {
             mlp,
             input_layernorm,
             post_attention_layernorm,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
 
@@ -315,6 +323,7 @@
         attention_mask: Option<&Tensor>,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let xs = self.input_layernorm.forward(xs)?;
         let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
@@ -333,6 +342,7 @@ pub struct Model {
     lm_head: Linear,
     device: Device,
     dtype: DType,
+    span: tracing::Span,
 }
 
 impl Model {
@@ -356,6 +366,7 @@ impl Model {
             lm_head,
             device: vb.device().clone(),
             dtype: vb.dtype(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }
 
@@ -381,6 +392,7 @@
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_size, seq_len) = input_ids.dims2()?;
         let attention_mask = if seq_len <= 1 {
             None
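
Not shown in the diff: these spans stay inert until a tracing subscriber is installed in the driver binary. Below is a minimal sketch of such a harness, assuming the tracing-chrome and tracing-subscriber crates (the same pair the candle example binaries use for profiling); the model construction itself is elided.

    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    fn main() {
        // Build a Chrome-trace layer; the guard flushes a trace-*.json file
        // on drop, so keep it bound for the whole run.
        let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();

        // ... construct the model and call `model.forward(&input_ids, seqlen_offset)`
        // here. Each forward pass then records nested "model" / "layer" /
        // "attn" / "mlp" spans, one enter/exit pair per module call, viewable
        // in chrome://tracing or Perfetto.
    }

Since each span is created once in the constructor and only entered in forward, the per-call cost when no TRACE-level subscriber is active is a cheap enabled check, so the instrumentation is essentially free in normal runs.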