From 158ff3c609b22ed998dea5283738cc1ed13aa592 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 8 Sep 2023 15:31:29 +0100 Subject: [PATCH] Add tracing to segment-anything (#777) * Tracing support for segment-anything. * More tracing. * Handle the empty slice case. --- candle-core/src/cuda_backend.rs | 3 ++ .../examples/segment-anything/main.rs | 47 ++++++++++++++++--- .../segment-anything/model_image_encoder.rs | 39 ++++++++++++--- .../segment-anything/model_mask_decoder.rs | 12 ++++- .../segment-anything/model_prompt_encoder.rs | 4 ++ candle-examples/examples/whisper/main.rs | 1 - 6 files changed, 90 insertions(+), 16 deletions(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 2180be5e..cb00441f 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1889,6 +1889,9 @@ impl BackendStorage for CudaStorage { let src_shape = src_l.shape(); let dims = src_shape.dims(); let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?; diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index c5095c0e..a749ba2a 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -14,15 +14,17 @@ pub mod model_sam; pub mod model_transformer; use candle::{DType, Result, Tensor}; -use candle_nn::{Linear, Module, VarBuilder}; +use candle_nn::{Module, VarBuilder}; use clap::Parser; pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result { - if bias { - candle_nn::linear(in_dim, out_dim, vb) + let inner = if bias { + candle_nn::linear(in_dim, out_dim, vb)? } else { - candle_nn::linear_no_bias(in_dim, out_dim, vb) - } + candle_nn::linear_no_bias(in_dim, out_dim, vb)? + }; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { inner, span }) } #[derive(Debug)] @@ -62,6 +64,7 @@ pub struct MlpBlock { lin1: Linear, lin2: Linear, activation: candle_nn::Activation, + span: tracing::Span, } impl MlpBlock { @@ -71,24 +74,40 @@ impl MlpBlock { activation: candle_nn::Activation, vb: VarBuilder, ) -> Result { - let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?; - let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?; + let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; + let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; + let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); Ok(Self { lin1, lin2, activation, + span, }) } } impl Module for MlpBlock { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); xs.apply(&self.lin1)? .apply(&self.activation)? .apply(&self.lin2) } } +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Module for Linear { + fn forward(&self, x: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + #[derive(Parser)] struct Args { #[arg(long)] @@ -109,10 +128,24 @@ struct Args { #[arg(long, default_value_t = 0.5)] point_y: f64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, } pub fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; let device = candle_examples::device(args.cpu)?; diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs index f1b76e23..f997170d 100644 --- a/candle-examples/examples/segment-anything/model_image_encoder.rs +++ b/candle-examples/examples/segment-anything/model_image_encoder.rs @@ -1,9 +1,10 @@ use candle::{DType, IndexOp, Result, Tensor}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder}; #[derive(Debug)] struct PatchEmbed { proj: candle_nn::Conv2d, + span: tracing::Span, } impl PatchEmbed { @@ -21,23 +22,28 @@ impl PatchEmbed { ..Default::default() }; let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?; - Ok(Self { proj }) + let span = tracing::span!(tracing::Level::TRACE, "patch-embed"); + Ok(Self { proj, span }) } } impl Module for PatchEmbed { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); xs.apply(&self.proj)?.permute((0, 2, 3, 1)) } } #[derive(Debug)] struct Attention { - qkv: Linear, - proj: Linear, + qkv: crate::Linear, + proj: crate::Linear, num_heads: usize, scale: f64, rel_pos_hw: Option<(Tensor, Tensor)>, + span: tracing::Span, + span_rel_pos: tracing::Span, + span_softmax: tracing::Span, } impl Attention { @@ -49,6 +55,9 @@ impl Attention { input_size: (usize, usize), vb: VarBuilder, ) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "attention"); + let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; let proj = crate::linear(vb.pp("proj"), dim, dim, true)?; let head_dim = dim / num_heads; @@ -66,6 +75,9 @@ impl Attention { num_heads, scale, rel_pos_hw, + span, + span_rel_pos, + span_softmax, }) } @@ -126,6 +138,7 @@ fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result impl Module for Attention { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let (b, h, w, c) = xs.dims4()?; let qkv = self .qkv @@ -137,8 +150,14 @@ impl Module for Attention { let k = qkv.i(1)?; let v = qkv.i(2)?; let attn = (&q * self.scale)?.matmul(&k.t()?)?; - let attn = self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?; - let attn = candle_nn::ops::softmax_last_dim(&attn)?; + let attn = { + let _enter = self.span_rel_pos.enter(); + self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))? + }; + let attn = { + let _enter = self.span_softmax.enter(); + candle_nn::ops::softmax_last_dim(&attn)? + }; let attn = attn.matmul(&v)?; let attn = attn .reshape((b, self.num_heads, h, w, c / self.num_heads))? @@ -155,6 +174,7 @@ struct Block { norm2: LayerNorm, mlp: crate::MlpBlock, window_size: usize, + span: tracing::Span, } impl Block { @@ -183,12 +203,14 @@ impl Block { vb.pp("attn"), )?; let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?; + let span = tracing::span!(tracing::Level::TRACE, "ie-block"); Ok(Self { norm1, attn, norm2, mlp, window_size, + span, }) } } @@ -249,6 +271,7 @@ fn window_unpartition( impl Module for Block { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let shortcut = xs; let xs = self.norm1.forward(xs)?; let hw = (xs.dim(1)?, xs.dim(2)?); @@ -277,6 +300,7 @@ pub struct ImageEncoderViT { neck_conv2: candle_nn::Conv2d, neck_ln2: crate::LayerNorm2d, pos_embed: Option, + span: tracing::Span, } impl ImageEncoderViT { @@ -346,6 +370,7 @@ impl ImageEncoderViT { } else { None }; + let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit"); Ok(Self { patch_embed, blocks, @@ -354,12 +379,14 @@ impl ImageEncoderViT { neck_conv2, neck_ln2, pos_embed, + span, }) } } impl Module for ImageEncoderViT { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let xs = self.patch_embed.forward(xs)?; let mut xs = match &self.pos_embed { Some(pos_embed) => (xs + pos_embed)?, diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs index 598af1f6..1f6d62a4 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs @@ -1,12 +1,13 @@ use candle::{IndexOp, Result, Tensor}; -use candle_nn::{Linear, Module, VarBuilder}; +use candle_nn::{Module, VarBuilder}; use crate::model_transformer::TwoWayTransformer; #[derive(Debug)] struct MlpMaskDecoder { - layers: Vec, + layers: Vec, sigmoid_output: bool, + span: tracing::Span, } impl MlpMaskDecoder { @@ -30,15 +31,18 @@ impl MlpMaskDecoder { let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?; layers.push(layer) } + let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder"); Ok(Self { layers, sigmoid_output, + span, }) } } impl Module for MlpMaskDecoder { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let mut xs = xs.clone(); for (i, layer) in self.layers.iter().enumerate() { xs = layer.forward(&xs)?; @@ -65,6 +69,7 @@ pub struct MaskDecoder { num_mask_tokens: usize, output_hypernetworks_mlps: Vec, transformer: TwoWayTransformer, + span: tracing::Span, } impl MaskDecoder { @@ -127,6 +132,7 @@ impl MaskDecoder { /* mlp_dim */ 2048, vb.pp("transformer"), )?; + let span = tracing::span!(tracing::Level::TRACE, "mask-decoder"); Ok(Self { iou_token, mask_tokens, @@ -137,6 +143,7 @@ impl MaskDecoder { num_mask_tokens, output_hypernetworks_mlps, transformer, + span, }) } @@ -148,6 +155,7 @@ impl MaskDecoder { dense_prompt_embeddings: &Tensor, multimask_output: bool, ) -> Result<(Tensor, Tensor)> { + let _enter = self.span.enter(); let (masks, iou_pred) = self.predict_masks( image_embeddings, image_pe, diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs index b401a900..40cc6e36 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs @@ -64,6 +64,7 @@ pub struct PromptEncoder { image_embedding_size: (usize, usize), input_image_size: (usize, usize), embed_dim: usize, + span: tracing::Span, } impl PromptEncoder { @@ -108,6 +109,7 @@ impl PromptEncoder { let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?; point_embeddings.push(emb) } + let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder"); Ok(Self { pe_layer, point_embeddings, @@ -121,6 +123,7 @@ impl PromptEncoder { image_embedding_size, input_image_size, embed_dim, + span, }) } @@ -201,6 +204,7 @@ impl PromptEncoder { boxes: Option<&Tensor>, masks: Option<&Tensor>, ) -> Result<(Tensor, Tensor)> { + let _enter = self.span.enter(); let se_points = match points { Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?), None => None, diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 5dd8ee20..dbe9cc8d 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -431,7 +431,6 @@ fn main() -> Result<()> { let args = Args::parse(); let _guard = if args.tracing { - println!("tracing..."); let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); Some(guard)