mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add tracing to segment-anything (#777)
* Tracing support for segment-anything. * More tracing. * Handle the empty slice case.
This commit is contained in:
@ -64,6 +64,7 @@ pub struct PromptEncoder {
|
||||
image_embedding_size: (usize, usize),
|
||||
input_image_size: (usize, usize),
|
||||
embed_dim: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl PromptEncoder {
|
||||
@ -108,6 +109,7 @@ impl PromptEncoder {
|
||||
let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
|
||||
point_embeddings.push(emb)
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
|
||||
Ok(Self {
|
||||
pe_layer,
|
||||
point_embeddings,
|
||||
@ -121,6 +123,7 @@ impl PromptEncoder {
|
||||
image_embedding_size,
|
||||
input_image_size,
|
||||
embed_dim,
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
@ -201,6 +204,7 @@ impl PromptEncoder {
|
||||
boxes: Option<&Tensor>,
|
||||
masks: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let _enter = self.span.enter();
|
||||
let se_points = match points {
|
||||
Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
|
||||
None => None,
|
||||
|
Reference in New Issue
Block a user