Add tracing to segment-anything (#777)

* Tracing support for segment-anything.

* More tracing.

* Handle the empty slice case.
This commit is contained in:
Laurent Mazare
2023-09-08 15:31:29 +01:00
committed by GitHub
parent e5703d2f56
commit 158ff3c609
6 changed files with 90 additions and 16 deletions

View File

@ -64,6 +64,7 @@ pub struct PromptEncoder {
image_embedding_size: (usize, usize),
input_image_size: (usize, usize),
embed_dim: usize,
span: tracing::Span,
}
impl PromptEncoder {
@ -108,6 +109,7 @@ impl PromptEncoder {
let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
point_embeddings.push(emb)
}
let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
Ok(Self {
pe_layer,
point_embeddings,
@ -121,6 +123,7 @@ impl PromptEncoder {
image_embedding_size,
input_image_size,
embed_dim,
span,
})
}
@ -201,6 +204,7 @@ impl PromptEncoder {
boxes: Option<&Tensor>,
masks: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();
let se_points = match points {
Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
None => None,