Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00.
Add tracing to segment-anything (#777)

* Tracing support for segment-anything.
* More tracing.
* Handle the empty slice case.
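The diff applies a single pattern throughout the example: each layer struct gains a tracing::Span field that is created once in the constructor, and forward() enters that span so the time spent in the layer shows up as a slice in the trace. A minimal sketch of the pattern in isolation (ReluBlock is an invented name for illustration, not part of the commit):

    use candle::{Result, Tensor};
    use candle_nn::Module;

    #[derive(Debug)]
    struct ReluBlock {
        span: tracing::Span,
    }

    impl ReluBlock {
        fn new() -> Self {
            // Create the span once; entering it on each call is cheap.
            let span = tracing::span!(tracing::Level::TRACE, "relu-block");
            Self { span }
        }
    }

    impl Module for ReluBlock {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            // The guard records entry and exit, i.e. the duration of this forward pass.
            let _enter = self.span.enter();
            xs.relu()
        }
    }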
candle-examples/examples/segment-anything/main.rs

@@ -14,15 +14,17 @@ pub mod model_sam;
 pub mod model_transformer;

 use candle::{DType, Result, Tensor};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::{Module, VarBuilder};
 use clap::Parser;

 pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
-    if bias {
-        candle_nn::linear(in_dim, out_dim, vb)
+    let inner = if bias {
+        candle_nn::linear(in_dim, out_dim, vb)?
     } else {
-        candle_nn::linear_no_bias(in_dim, out_dim, vb)
-    }
+        candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+    };
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    Ok(Linear { inner, span })
 }

 #[derive(Debug)]
@@ -62,6 +64,7 @@ pub struct MlpBlock {
     lin1: Linear,
     lin2: Linear,
     activation: candle_nn::Activation,
+    span: tracing::Span,
 }

 impl MlpBlock {
@@ -71,24 +74,40 @@ impl MlpBlock {
         activation: candle_nn::Activation,
         vb: VarBuilder,
     ) -> Result<Self> {
-        let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
-        let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
+        let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
+        let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
+        let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
         Ok(Self {
             lin1,
             lin2,
             activation,
+            span,
         })
     }
 }

 impl Module for MlpBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.lin1)?
             .apply(&self.activation)?
             .apply(&self.lin2)
     }
 }
+
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}

 #[derive(Parser)]
 struct Args {
     #[arg(long)]
@@ -109,10 +128,24 @@ struct Args {

     #[arg(long, default_value_t = 0.5)]
     point_y: f64,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
 }

 pub fn main() -> anyhow::Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };

     let device = candle_examples::device(args.cpu)?;
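One subtlety in the main() change above: ChromeLayerBuilder::new().build() returns the layer together with a flush guard, and the trace file is only finalized when that guard is dropped, which is why it is bound as let _guard rather than let _. In Rust, let _ = value drops the value immediately, while let _guard = value keeps it alive until the end of the scope. A self-contained sketch of the difference (FlushOnDrop is an invented stand-in for the real guard type):

    struct FlushOnDrop;

    impl Drop for FlushOnDrop {
        fn drop(&mut self) {
            println!("flushing trace");
        }
    }

    fn main() {
        let _ = FlushOnDrop; // dropped (and "flushed") right here
        let _guard = FlushOnDrop; // lives until the end of main
        println!("tracing stays active while the guard is alive");
    }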
candle-examples/examples/segment-anything/model_image_encoder.rs

@@ -1,9 +1,10 @@
 use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};

 #[derive(Debug)]
 struct PatchEmbed {
     proj: candle_nn::Conv2d,
+    span: tracing::Span,
 }

 impl PatchEmbed {
@@ -21,23 +22,28 @@ impl PatchEmbed {
             ..Default::default()
         };
         let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
-        Ok(Self { proj })
+        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+        Ok(Self { proj, span })
     }
 }

 impl Module for PatchEmbed {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.proj)?.permute((0, 2, 3, 1))
     }
 }

 #[derive(Debug)]
 struct Attention {
-    qkv: Linear,
-    proj: Linear,
+    qkv: crate::Linear,
+    proj: crate::Linear,
     num_heads: usize,
     scale: f64,
     rel_pos_hw: Option<(Tensor, Tensor)>,
+    span: tracing::Span,
+    span_rel_pos: tracing::Span,
+    span_softmax: tracing::Span,
 }

 impl Attention {
@@ -49,6 +55,9 @@ impl Attention {
         input_size: (usize, usize),
         vb: VarBuilder,
     ) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
+        let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
         let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
         let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
         let head_dim = dim / num_heads;
@@ -66,6 +75,9 @@ impl Attention {
             num_heads,
             scale,
             rel_pos_hw,
+            span,
+            span_rel_pos,
+            span_softmax,
         })
     }

@@ -126,6 +138,7 @@ fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor>

 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, h, w, c) = xs.dims4()?;
         let qkv = self
             .qkv
@@ -137,8 +150,14 @@ impl Module for Attention {
         let k = qkv.i(1)?;
         let v = qkv.i(2)?;
         let attn = (&q * self.scale)?.matmul(&k.t()?)?;
-        let attn = self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?;
-        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+        let attn = {
+            let _enter = self.span_rel_pos.enter();
+            self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
+        };
+        let attn = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn)?
+        };
         let attn = attn.matmul(&v)?;
         let attn = attn
             .reshape((b, self.num_heads, h, w, c / self.num_heads))?
@@ -155,6 +174,7 @@ struct Block {
     norm2: LayerNorm,
     mlp: crate::MlpBlock,
     window_size: usize,
+    span: tracing::Span,
 }

 impl Block {
@@ -183,12 +203,14 @@ impl Block {
             vb.pp("attn"),
         )?;
         let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "ie-block");
         Ok(Self {
             norm1,
             attn,
             norm2,
             mlp,
             window_size,
+            span,
         })
     }
 }
@@ -249,6 +271,7 @@ fn window_unpartition(

 impl Module for Block {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut = xs;
         let xs = self.norm1.forward(xs)?;
         let hw = (xs.dim(1)?, xs.dim(2)?);
@@ -277,6 +300,7 @@ pub struct ImageEncoderViT {
     neck_conv2: candle_nn::Conv2d,
     neck_ln2: crate::LayerNorm2d,
     pos_embed: Option<Tensor>,
+    span: tracing::Span,
 }

 impl ImageEncoderViT {
@@ -346,6 +370,7 @@ impl ImageEncoderViT {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
         Ok(Self {
             patch_embed,
             blocks,
@@ -354,12 +379,14 @@ impl ImageEncoderViT {
             neck_conv2,
             neck_ln2,
             pos_embed,
+            span,
         })
     }
 }

 impl Module for ImageEncoderViT {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.patch_embed.forward(xs)?;
         let mut xs = match &self.pos_embed {
             Some(pos_embed) => (xs + pos_embed)?,
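Beyond the per-layer spans, Attention gets two extra spans (attn-rel-pos and attn-sm) that are entered only around individual sub-expressions, so the relative-position bias and the softmax appear as separate slices nested inside each attention slice. A runnable sketch of how such nested spans end up in the Chrome trace (span names are illustrative):

    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    fn main() {
        // Writes a trace-<timestamp>.json file in the current directory.
        let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();

        let attention = tracing::span!(tracing::Level::TRACE, "attention");
        let softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");

        let _outer = attention.enter();
        {
            // This slice nests under "attention" in the trace viewer.
            let _inner = softmax.enter();
        }
    }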
candle-examples/examples/segment-anything/model_mask_decoder.rs

@@ -1,12 +1,13 @@
 use candle::{IndexOp, Result, Tensor};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::{Module, VarBuilder};

 use crate::model_transformer::TwoWayTransformer;

 #[derive(Debug)]
 struct MlpMaskDecoder {
-    layers: Vec<Linear>,
+    layers: Vec<crate::Linear>,
     sigmoid_output: bool,
+    span: tracing::Span,
 }

 impl MlpMaskDecoder {
@@ -30,15 +31,18 @@ impl MlpMaskDecoder {
             let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
             layers.push(layer)
         }
+        let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
         Ok(Self {
             layers,
             sigmoid_output,
+            span,
         })
     }
 }

 impl Module for MlpMaskDecoder {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for (i, layer) in self.layers.iter().enumerate() {
             xs = layer.forward(&xs)?;
@@ -65,6 +69,7 @@ pub struct MaskDecoder {
     num_mask_tokens: usize,
     output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
     transformer: TwoWayTransformer,
+    span: tracing::Span,
 }

 impl MaskDecoder {
@@ -127,6 +132,7 @@ impl MaskDecoder {
             /* mlp_dim */ 2048,
             vb.pp("transformer"),
         )?;
+        let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
         Ok(Self {
             iou_token,
             mask_tokens,
@@ -137,6 +143,7 @@ impl MaskDecoder {
             num_mask_tokens,
             output_hypernetworks_mlps,
             transformer,
+            span,
         })
     }

@@ -148,6 +155,7 @@ impl MaskDecoder {
         dense_prompt_embeddings: &Tensor,
         multimask_output: bool,
     ) -> Result<(Tensor, Tensor)> {
+        let _enter = self.span.enter();
         let (masks, iou_pred) = self.predict_masks(
             image_embeddings,
             image_pe,
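Note that swapping the field type from Vec<Linear> to Vec<crate::Linear> is the entire change needed here: both types implement candle_nn::Module, so the existing layer.forward(&xs) calls compile unchanged. A sketch of why the wrapper is a drop-in replacement, written against the trait rather than a concrete type (apply_all is a hypothetical helper, not part of the commit):

    use candle::{Result, Tensor};
    use candle_nn::Module;

    // Works for candle_nn::Linear and for the tracing wrapper alike,
    // since both implement Module.
    fn apply_all<M: Module>(layers: &[M], xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in layers {
            xs = layer.forward(&xs)?;
        }
        Ok(xs)
    }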
candle-examples/examples/segment-anything/model_prompt_encoder.rs

@@ -64,6 +64,7 @@ pub struct PromptEncoder {
     image_embedding_size: (usize, usize),
     input_image_size: (usize, usize),
     embed_dim: usize,
+    span: tracing::Span,
 }

 impl PromptEncoder {
@@ -108,6 +109,7 @@ impl PromptEncoder {
             let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
             point_embeddings.push(emb)
         }
+        let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
         Ok(Self {
             pe_layer,
             point_embeddings,
@@ -121,6 +123,7 @@ impl PromptEncoder {
             image_embedding_size,
             input_image_size,
             embed_dim,
+            span,
         })
     }

@@ -201,6 +204,7 @@ impl PromptEncoder {
         boxes: Option<&Tensor>,
         masks: Option<&Tensor>,
     ) -> Result<(Tensor, Tensor)> {
+        let _enter = self.span.enter();
         let se_points = match points {
             Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
             None => None,
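With all four files instrumented, running the example with the new flag, e.g. cargo run --example segment-anything --release -- --tracing (plus whatever image and point arguments the example already expects), writes the trace-timestamp.json file mentioned in the flag's doc comment; it can be opened in chrome://tracing or Perfetto to inspect the per-layer timings.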