From 90e077e4093510917e5762de92384ae312162f23 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 10 Sep 2023 13:03:02 +0100 Subject: [PATCH] Return the low res mask in the wasm segment-anything module. (#798) * Return the low res mask. * Add some validations. --- .../src/models/segment_anything/sam.rs | 18 +++++++++--------- .../segment-anything/src/bin/m.rs | 12 +++++++++++- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs index 92756591..07e9a759 100644 --- a/candle-transformers/src/models/segment_anything/sam.rs +++ b/candle-transformers/src/models/segment_anything/sam.rs @@ -136,13 +136,18 @@ impl Sam { let (_c, original_h, original_w) = img.dims3()?; let img = self.preprocess(img)?.unsqueeze(0)?; let img_embeddings = self.image_encoder.forward(&img)?; - self.forward_for_embeddings( + let (low_res_mask, iou) = self.forward_for_embeddings( &img_embeddings, original_h, original_w, point, multimask_output, - ) + )?; + let mask = low_res_mask + .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)? + .get(0)? + .i((.., ..original_h, ..original_w))?; + Ok((mask, iou)) } pub fn forward_for_embeddings( @@ -168,18 +173,13 @@ impl Sam { let points = points.as_ref().map(|(x, y)| (x, y)); let (sparse_prompt_embeddings, dense_prompt_embeddings) = self.prompt_encoder.forward(points, None, None)?; - let (low_res_mask, iou_predictions) = self.mask_decoder.forward( + self.mask_decoder.forward( img_embeddings, &image_pe, &sparse_prompt_embeddings, &dense_prompt_embeddings, multimask_output, - )?; - let mask = low_res_mask - .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)? - .get(0)? - .i((.., ..original_h, ..original_w))?; - Ok((mask, iou_predictions)) + ) } pub fn unpreprocess(&self, img: &Tensor) -> Result { diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs index c4c79fe0..b53f5b9b 100644 --- a/candle-wasm-examples/segment-anything/src/bin/m.rs +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -77,8 +77,18 @@ impl Model { // x and y have to be between 0 and 1 pub fn mask_for_point(&self, x: f64, y: f64) -> Result { + if !(0. ..=1.).contains(&x) { + Err(JsError::new(&format!( + "x has to be between 0 and 1, got {x}" + )))? + } + if !(0. ..=1.).contains(&y) { + Err(JsError::new(&format!( + "y has to be between 0 and 1, got {y}" + )))? + } let embeddings = match &self.embeddings { - None => todo!(), + None => Err(JsError::new("image embeddings have not been set"))?, Some(embeddings) => embeddings, }; let (mask, iou_predictions) = self.sam.forward_for_embeddings(