diff --git a/candle-examples/examples/segment-anything/model_tiny_vit.rs b/candle-examples/examples/segment-anything/model_tiny_vit.rs
index b3941ee1..ff076773 100644
--- a/candle-examples/examples/segment-anything/model_tiny_vit.rs
+++ b/candle-examples/examples/segment-anything/model_tiny_vit.rs
@@ -13,18 +13,21 @@ const IN_CHANNELS: usize = 3;
 struct Conv2dBN {
     c: candle_nn::Conv2d,
     bn: candle_nn::BatchNorm,
+    span: tracing::Span,
 }
 
 impl Conv2dBN {
     fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
         let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
         let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
-        Ok(Self { c, bn })
+        let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
+        Ok(Self { c, bn, span })
     }
 }
 
 impl Module for Conv2dBN {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.c)?.apply(&self.bn)
     }
 }
@@ -33,6 +36,7 @@ impl Module for Conv2dBN {
 struct PatchEmbed {
     conv1: Conv2dBN,
     conv2: Conv2dBN,
+    span: tracing::Span,
 }
 
 impl PatchEmbed {
@@ -44,12 +48,14 @@ impl PatchEmbed {
         };
         let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
         let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
-        Ok(Self { conv1, conv2 })
+        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+        Ok(Self { conv1, conv2, span })
     }
 }
 
 impl Module for PatchEmbed {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
     }
 }
@@ -59,6 +65,7 @@ struct MBConv {
     conv1: Conv2dBN,
     conv2: Conv2dBN,
     conv3: Conv2dBN,
+    span: tracing::Span,
 }
 
 impl MBConv {
@@ -72,16 +79,19 @@ impl MBConv {
         let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
         let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
         let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
         Ok(Self {
             conv1,
             conv2,
             conv3,
+            span,
         })
     }
 }
 
 impl Module for MBConv {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut = xs;
         let xs = xs
             .apply(&self.conv1)?
@@ -99,6 +109,7 @@ struct PatchMerging {
     conv2: Conv2dBN,
     conv3: Conv2dBN,
     input_resolution: (usize, usize),
+    span: tracing::Span,
 }
 
 impl PatchMerging {
@@ -118,17 +129,20 @@ impl PatchMerging {
         let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
         let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
         let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
         Ok(Self {
             conv1,
             conv2,
             conv3,
             input_resolution,
+            span,
         })
     }
 }
 
 impl Module for PatchMerging {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = if xs.rank() == 3 {
             let (h, w) = self.input_resolution;
             let b = xs.dim(0)?;
@@ -150,6 +164,7 @@ impl Module for PatchMerging {
 struct ConvLayer {
     blocks: Vec<MBConv>,
     downsample: Option<PatchMerging>,
+    span: tracing::Span,
 }
 
 impl ConvLayer {
@@ -174,12 +189,18 @@ impl ConvLayer {
         } else {
             None
         };
-        Ok(Self { blocks, downsample })
+        let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
+        Ok(Self {
+            blocks,
+            downsample,
+            span,
+        })
     }
 }
 
 impl Module for ConvLayer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for block in self.blocks.iter() {
             xs = block.forward(&xs)?
@@ -194,21 +215,29 @@ impl Module for ConvLayer {
 #[derive(Debug)]
 struct Mlp {
     norm: candle_nn::LayerNorm,
-    fc1: candle_nn::Linear,
-    fc2: candle_nn::Linear,
+    fc1: crate::Linear,
+    fc2: crate::Linear,
+    span: tracing::Span,
 }
 
 impl Mlp {
     fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
         let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
-        let fc1 = candle_nn::linear(in_, hidden, vb.pp("fc1"))?;
-        let fc2 = candle_nn::linear(hidden, in_, vb.pp("fc2"))?;
-        Ok(Self { norm, fc1, fc2 })
+        let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
+        let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
+        let span = tracing::span!(tracing::Level::TRACE, "mlp");
+        Ok(Self {
+            norm,
+            fc1,
+            fc2,
+            span,
+        })
     }
 }
 
 impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.norm)?
             .apply(&self.fc1)?
             .gelu()?
@@ -219,14 +248,17 @@
 #[derive(Debug)]
 struct Attention {
     norm: candle_nn::LayerNorm,
-    qkv: candle_nn::Linear,
-    proj: candle_nn::Linear,
+    qkv: crate::Linear,
+    proj: crate::Linear,
     ab: Tensor,
     key_dim: usize,
     num_heads: usize,
     d: usize,
     dh: usize,
     scale: f64,
+    span: tracing::Span,
+    span_matmul: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl Attention {
@@ -243,8 +275,8 @@ impl Attention {
         let nh_kd = key_dim * num_heads;
         let h = dh + nh_kd * 2;
         let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
-        let qkv = candle_nn::linear(dim, h, vb.pp("qkv"))?;
-        let proj = candle_nn::linear(dh, dim, vb.pp("proj"))?;
+        let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
+        let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
 
         let points = (0..resolution.0)
             .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
@@ -265,6 +297,9 @@
         attention_biases
             .index_select(&idxs, 1)?
             .reshape(((), points.len(), points.len()))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
+        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
         Ok(Self {
             norm,
             qkv,
@@ -275,12 +310,16 @@
             d,
             dh,
             scale: 1f64 / (key_dim as f64).sqrt(),
+            span,
+            span_matmul,
+            span_softmax,
         })
     }
 }
 
 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, n, _) = xs.dims3()?;
         let xs = xs.apply(&self.norm)?;
         let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
@@ -296,11 +335,20 @@
             .narrow(D::Minus1, 2 * self.key_dim, self.d)?
             .permute((0, 2, 1, 3))?
             .contiguous()?;
-        let attn = (q.matmul(&k.t()?)? * self.scale)?;
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            (q.matmul(&k.t()?)? * self.scale)?
+        };
         let attn = attn.broadcast_add(&self.ab)?;
-        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
-        attn.matmul(&v)?
-            .transpose(1, 2)?
+        let attn = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn)?
+        };
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            attn.matmul(&v)?
+        };
+        attn.transpose(1, 2)?
             .reshape((b, n, self.dh))?
             .apply(&self.proj)
     }
@@ -313,6 +361,7 @@ struct TinyViTBlock {
     mlp: Mlp,
     window_size: usize,
     input_resolution: (usize, usize),
+    span: tracing::Span,
 }
 
 impl TinyViTBlock {
@@ -339,18 +388,21 @@ impl TinyViTBlock {
             ..Default::default()
         };
         let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
         Ok(Self {
             attn,
             local_conv,
             mlp,
             window_size,
             input_resolution,
+            span,
         })
     }
 }
 
 impl Module for TinyViTBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (h, w) = self.input_resolution;
         let (b, l, c) = xs.dims3()?;
         let res_x = xs;
@@ -410,6 +462,7 @@ impl Module for TinyViTBlock {
 struct BasicLayer {
     blocks: Vec<TinyViTBlock>,
     downsample: Option<PatchMerging>,
+    span: tracing::Span,
 }
 
 impl BasicLayer {
@@ -442,12 +495,18 @@ impl BasicLayer {
         } else {
             None
         };
-        Ok(Self { blocks, downsample })
+        let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
+        Ok(Self {
+            blocks,
+            downsample,
+            span,
+        })
     }
 }
 
 impl Module for BasicLayer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for block in self.blocks.iter() {
             xs = block.forward(&xs)?
@@ -470,6 +529,8 @@ pub struct TinyViT {
     neck_ln1: crate::LayerNorm2d,
     neck_conv2: candle_nn::Conv2d,
     neck_ln2: crate::LayerNorm2d,
+    span: tracing::Span,
+    span_neck: tracing::Span,
 }
 
 impl TinyViT {
@@ -525,6 +586,8 @@ impl TinyViT {
         let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
         let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
 
+        let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
+        let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
         Ok(Self {
             patch_embed,
             layer0,
@@ -533,18 +596,22 @@ impl TinyViT {
             neck_ln1,
             neck_conv2,
             neck_ln2,
+            span,
+            span_neck,
         })
     }
 }
 
 impl Module for TinyViT {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.patch_embed.forward(xs)?;
         let mut xs = self.layer0.forward(&xs)?;
         for layer in self.layers.iter() {
             xs = layer.forward(&xs)?
         }
         let (b, _, c) = xs.dims3()?;
+        let _enter = self.span_neck.enter();
         xs.reshape((b, 64, 64, c))?
            .permute((0, 3, 1, 2))?
            .apply(&self.neck_conv1)?
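
For reference, the TRACE-level spans added above only produce output if a `tracing` subscriber is installed before the model's forward pass runs. Below is a minimal sketch of how they could be recorded, assuming the `tracing-chrome` and `tracing-subscriber` crates are available (this mirrors the pattern the candle examples use behind a `--tracing` flag); the surrounding model setup is elided and the exact output file name depends on the tracing-chrome defaults.

```rust
// Sketch: capture the spans added in this diff into a Chrome trace file
// that can be opened in chrome://tracing or Perfetto.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run; the trace file is flushed
    // when the guard is dropped at the end of main.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... build the TinyViT image encoder and run its forward pass here;
    // each `self.span.enter()` in the model then shows up as a timed
    // slice ("conv2d-bn", "attention", "attn-matmul", "neck", ...).
}
```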