Add tracing to segment-anything (#777)

* Tracing support for segment-anything.

* More tracing.

* Handle the empty slice case.
Author: Laurent Mazare
Date: 2023-09-08 15:31:29 +01:00
Committed by: GitHub
Parent: e5703d2f56
Commit: 158ff3c609

6 changed files with 90 additions and 16 deletions
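The change follows one pattern throughout the model files: each struct gains a `span: tracing::Span` field built once in its constructor, and each `forward` enters it so the call shows up as a slice in the generated chrome trace. Distilled into a generic wrapper for illustration (the diff below uses concrete per-struct fields; this `Traced` type is a sketch, not code from the commit):

use candle::{Result, Tensor};
use candle_nn::Module;

// Pairs any module with a tracing span (illustrative only).
struct Traced<M> {
    inner: M,
    span: tracing::Span,
}

impl<M> Traced<M> {
    fn new(inner: M, name: &'static str) -> Self {
        // Build the span once; entering it later is cheap.
        let span = tracing::span!(tracing::Level::TRACE, "module", name);
        Traced { inner, span }
    }
}

impl<M: Module> Module for Traced<M> {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // The guard opens a slice for the duration of the inner forward
        // pass and closes it when dropped at the end of the scope.
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}

The commit instruments each struct directly rather than going through a wrapper like this, which repeats the two-line pattern per constructor and forward but avoids an extra generic layer.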


@@ -1889,6 +1889,9 @@ impl BackendStorage for CudaStorage {
         let src_shape = src_l.shape();
         let dims = src_shape.dims();
         let el_count = src_shape.elem_count();
+        if el_count == 0 {
+            return Ok(());
+        }
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = &self.device;
         let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
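Without the early return, an empty slice would reach `LaunchConfig::for_num_elems(0)` and the kernel launch that follows; a zero-element launch config presumably yields a zero-sized grid, which CUDA rejects, so copying an empty tensor has to be treated as a no-op up front.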


@@ -14,15 +14,17 @@ pub mod model_sam;
 pub mod model_transformer;
 
 use candle::{DType, Result, Tensor};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::{Module, VarBuilder};
 use clap::Parser;
 
 pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
-    if bias {
-        candle_nn::linear(in_dim, out_dim, vb)
+    let inner = if bias {
+        candle_nn::linear(in_dim, out_dim, vb)?
     } else {
-        candle_nn::linear_no_bias(in_dim, out_dim, vb)
-    }
+        candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+    };
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    Ok(Linear { inner, span })
 }
 
 #[derive(Debug)]
@@ -62,6 +64,7 @@ pub struct MlpBlock {
     lin1: Linear,
     lin2: Linear,
     activation: candle_nn::Activation,
+    span: tracing::Span,
 }
 
 impl MlpBlock {
@@ -71,24 +74,40 @@ impl MlpBlock {
         activation: candle_nn::Activation,
         vb: VarBuilder,
     ) -> Result<Self> {
-        let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
-        let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
+        let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
+        let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
+        let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
         Ok(Self {
             lin1,
             lin2,
             activation,
+            span,
         })
     }
 }
 
 impl Module for MlpBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.lin1)?
             .apply(&self.activation)?
             .apply(&self.lin2)
     }
 }
+
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
 
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
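Defining the wrapper locally, rather than instrumenting candle-nn itself, keeps the change contained: the example drops `Linear` from its candle_nn imports, and every call site goes through the `linear` helper (or names `crate::Linear` in field types) to pick up the span for free.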
@@ -109,10 +128,24 @@ struct Args {
     #[arg(long, default_value_t = 0.5)]
     point_y: f64,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
 }
 
 pub fn main() -> anyhow::Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let args = Args::parse();
+
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
     let device = candle_examples::device(args.cpu)?;
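When run with `--tracing`, tracing-chrome writes a trace-timestamp.json file into the current working directory; it can be opened in chrome://tracing (or Perfetto) to inspect how much time each of the spans above accounts for.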


@@ -1,9 +1,10 @@
 use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
 
 #[derive(Debug)]
 struct PatchEmbed {
     proj: candle_nn::Conv2d,
+    span: tracing::Span,
 }
 
 impl PatchEmbed {
@@ -21,23 +22,28 @@ impl PatchEmbed {
             ..Default::default()
         };
         let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
-        Ok(Self { proj })
+        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+        Ok(Self { proj, span })
     }
 }
 
 impl Module for PatchEmbed {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.proj)?.permute((0, 2, 3, 1))
     }
 }
 
 #[derive(Debug)]
 struct Attention {
-    qkv: Linear,
-    proj: Linear,
+    qkv: crate::Linear,
+    proj: crate::Linear,
     num_heads: usize,
     scale: f64,
     rel_pos_hw: Option<(Tensor, Tensor)>,
+    span: tracing::Span,
+    span_rel_pos: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl Attention {
@@ -49,6 +55,9 @@ impl Attention {
         input_size: (usize, usize),
         vb: VarBuilder,
     ) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
+        let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
         let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
         let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
         let head_dim = dim / num_heads;
@@ -66,6 +75,9 @@ impl Attention {
             num_heads,
             scale,
             rel_pos_hw,
+            span,
+            span_rel_pos,
+            span_softmax,
         })
     }
@@ -126,6 +138,7 @@ fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor>
 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, h, w, c) = xs.dims4()?;
         let qkv = self
             .qkv
@@ -137,8 +150,14 @@ impl Module for Attention {
         let k = qkv.i(1)?;
         let v = qkv.i(2)?;
         let attn = (&q * self.scale)?.matmul(&k.t()?)?;
-        let attn = self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?;
-        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+        let attn = {
+            let _enter = self.span_rel_pos.enter();
+            self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
+        };
+        let attn = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn)?
+        };
         let attn = attn.matmul(&v)?;
         let attn = attn
             .reshape((b, self.num_heads, h, w, c / self.num_heads))?
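Wrapping the relative-position and softmax steps in their own scoped spans (entered via block expressions, so each guard drops at the end of its block) makes the two hottest pieces of the attention forward pass show up as separate slices nested under the outer `attention` span.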
@@ -155,6 +174,7 @@ struct Block {
     norm2: LayerNorm,
     mlp: crate::MlpBlock,
     window_size: usize,
+    span: tracing::Span,
 }
 
 impl Block {
@@ -183,12 +203,14 @@ impl Block {
             vb.pp("attn"),
         )?;
         let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "ie-block");
         Ok(Self {
             norm1,
             attn,
             norm2,
             mlp,
             window_size,
+            span,
         })
     }
 }
@@ -249,6 +271,7 @@ fn window_unpartition(
 impl Module for Block {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut = xs;
         let xs = self.norm1.forward(xs)?;
         let hw = (xs.dim(1)?, xs.dim(2)?);
@@ -277,6 +300,7 @@ pub struct ImageEncoderViT {
     neck_conv2: candle_nn::Conv2d,
     neck_ln2: crate::LayerNorm2d,
     pos_embed: Option<Tensor>,
+    span: tracing::Span,
 }
 
 impl ImageEncoderViT {
@@ -346,6 +370,7 @@ impl ImageEncoderViT {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
         Ok(Self {
             patch_embed,
             blocks,
@@ -354,12 +379,14 @@ impl ImageEncoderViT {
             neck_conv2,
             neck_ln2,
             pos_embed,
+            span,
         })
     }
 }
 
 impl Module for ImageEncoderViT {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.patch_embed.forward(xs)?;
         let mut xs = match &self.pos_embed {
             Some(pos_embed) => (xs + pos_embed)?,


@@ -1,12 +1,13 @@
 use candle::{IndexOp, Result, Tensor};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::{Module, VarBuilder};
 
 use crate::model_transformer::TwoWayTransformer;
 
 #[derive(Debug)]
 struct MlpMaskDecoder {
-    layers: Vec<Linear>,
+    layers: Vec<crate::Linear>,
     sigmoid_output: bool,
+    span: tracing::Span,
 }
 
 impl MlpMaskDecoder {
@@ -30,15 +31,18 @@ impl MlpMaskDecoder {
             let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
             layers.push(layer)
         }
+        let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
         Ok(Self {
             layers,
             sigmoid_output,
+            span,
         })
     }
 }
 
 impl Module for MlpMaskDecoder {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for (i, layer) in self.layers.iter().enumerate() {
             xs = layer.forward(&xs)?;
@@ -65,6 +69,7 @@ pub struct MaskDecoder {
     num_mask_tokens: usize,
     output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
     transformer: TwoWayTransformer,
+    span: tracing::Span,
 }
 
 impl MaskDecoder {
@@ -127,6 +132,7 @@ impl MaskDecoder {
             /* mlp_dim */ 2048,
             vb.pp("transformer"),
         )?;
+        let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
         Ok(Self {
             iou_token,
             mask_tokens,
@@ -137,6 +143,7 @@ impl MaskDecoder {
             num_mask_tokens,
             output_hypernetworks_mlps,
             transformer,
+            span,
         })
     }
@@ -148,6 +155,7 @@ impl MaskDecoder {
         dense_prompt_embeddings: &Tensor,
         multimask_output: bool,
     ) -> Result<(Tensor, Tensor)> {
+        let _enter = self.span.enter();
         let (masks, iou_pred) = self.predict_masks(
             image_embeddings,
             image_pe,


@@ -64,6 +64,7 @@ pub struct PromptEncoder {
     image_embedding_size: (usize, usize),
     input_image_size: (usize, usize),
     embed_dim: usize,
+    span: tracing::Span,
 }
 
 impl PromptEncoder {
@@ -108,6 +109,7 @@ impl PromptEncoder {
             let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
             point_embeddings.push(emb)
         }
+        let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
         Ok(Self {
             pe_layer,
             point_embeddings,
@@ -121,6 +123,7 @@ impl PromptEncoder {
             image_embedding_size,
             input_image_size,
             embed_dim,
+            span,
         })
     }
@@ -201,6 +204,7 @@ impl PromptEncoder {
         boxes: Option<&Tensor>,
         masks: Option<&Tensor>,
     ) -> Result<(Tensor, Tensor)> {
+        let _enter = self.span.enter();
         let se_points = match points {
             Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
             None => None,


@@ -431,7 +431,6 @@ fn main() -> Result<()> {
     let args = Args::parse();
 
     let _guard = if args.tracing {
-        println!("tracing...");
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
         Some(guard)