From c1453f00b11c9dd12c5aa81fb4355ce47d22d477 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 8 Sep 2023 09:39:10 +0100 Subject: [PATCH] Improve the safetensor loading in the segment-anything example. (#772) * Improve the safetensor loading in the segment-anything example. * Properly handle the labels when embedding the point prompts. --- .../examples/segment-anything/main.rs | 7 +++++- .../segment-anything/model_prompt_encoder.rs | 23 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 89d5b56c..c53c1010 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -110,7 +110,7 @@ pub fn main() -> anyhow::Result<()> { let image = if args.image.ends_with(".safetensors") { let mut tensors = candle::safetensors::load(&args.image, &device)?; - match tensors.remove("image") { + let image = match tensors.remove("image") { Some(image) => image, None => { if tensors.len() != 1 { @@ -118,6 +118,11 @@ pub fn main() -> anyhow::Result<()> { } tensors.into_values().next().unwrap() } + }; + if image.rank() == 4 { + image.get(0)? + } else { + image } } else { candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)? diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs index aab0c4fd..e4291ebb 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs @@ -157,7 +157,28 @@ impl PromptEncoder { let point_embedding = self .pe_layer .forward_with_coords(&points, self.input_image_size)?; - // TODO: tweak based on labels. + let zeros = point_embedding.zeros_like()?; + let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond( + &self + .not_a_point_embed + .embeddings() + .broadcast_as(zeros.shape())?, + &point_embedding, + )?; + let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond( + &self.point_embeddings[0] + .embeddings() + .broadcast_as(zeros.shape())?, + &zeros, + )?; + let point_embedding = (point_embedding + labels0)?; + let labels1 = labels.eq(&labels.ones_like()?)?.where_cond( + &self.point_embeddings[1] + .embeddings() + .broadcast_as(zeros.shape())?, + &zeros, + )?; + let point_embedding = (point_embedding + labels1)?; Ok(point_embedding) }