From 6203ced49597ae3de66081b98b9cc9fb9b41ba15 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 30 Sep 2023 07:17:42 +0200
Subject: [PATCH] Add negative prompts to segment-anything. (#1000)

---
 .../examples/segment-anything/main.rs  | 33 +++++++++++--------
 .../src/models/segment_anything/sam.rs | 17 +++++++---
 .../segment-anything/src/bin/m.rs      |  2 +-
 3 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 4d854610..911442a5 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -27,10 +27,16 @@ struct Args {
     #[arg(long)]
     generate_masks: bool,
 
-    /// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image).
+    /// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
+    /// should be part of the generated mask.
     #[arg(long)]
     point: Vec<String>,
 
+    /// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
+    /// should not be part of the generated mask and should be part of the background instead.
+    #[arg(long)]
+    neg_point: Vec<String>,
+
     /// The detection threshold for the mask, 0 is the default value, negative values mean a larger
     /// mask, positive makes the mask more selective.
     #[arg(long, default_value_t = 0.)]
@@ -107,16 +113,17 @@ pub fn main() -> anyhow::Result<()> {
             )?;
         }
     } else {
-        let points = args
-            .point
-            .iter()
-            .map(|point| {
+        let iter_points = args.point.iter().map(|p| (p, true));
+        let iter_neg_points = args.neg_point.iter().map(|p| (p, false));
+        let points = iter_points
+            .chain(iter_neg_points)
+            .map(|(point, b)| {
                 use std::str::FromStr;
                 let xy = point.split(',').collect::<Vec<_>>();
                 if xy.len() != 2 {
                     anyhow::bail!("expected format for points is 0.4,0.2")
                 }
-                Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?))
+                Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?, b))
             })
             .collect::<anyhow::Result<Vec<_>>>()?;
         let start_time = std::time::Instant::now();
@@ -158,15 +165,15 @@
                 }
             }
         }
-        for (x, y) in points {
+        for (x, y, b) in points {
             let x = (x * img.width() as f64) as i32;
             let y = (y * img.height() as f64) as i32;
-            imageproc::drawing::draw_filled_circle_mut(
-                &mut img,
-                (x, y),
-                3,
-                image::Rgba([255, 0, 0, 200]),
-            );
+            let color = if b {
+                image::Rgba([255, 0, 0, 200])
+            } else {
+                image::Rgba([0, 255, 0, 200])
+            };
+            imageproc::drawing::draw_filled_circle_mut(&mut img, (x, y), 3, color);
         }
         img.save("sam_merged.jpg")?
     }
diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
index 6de7beb2..ccf9ca7a 100644
--- a/candle-transformers/src/models/segment_anything/sam.rs
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -130,7 +130,7 @@ impl Sam {
     pub fn forward(
         &self,
         img: &Tensor,
-        points: &[(f64, f64)],
+        points: &[(f64, f64, bool)],
         multimask_output: bool,
     ) -> Result<(Tensor, Tensor)> {
         let (_c, original_h, original_w) = img.dims3()?;
@@ -150,12 +150,17 @@
         Ok((mask, iou))
     }
 
+    /// Generate the mask and IOU predictions from some image embeddings and prompt.
+    ///
+    /// The prompt is specified as a list of points `(x, y, b)`. `x` and `y` are the point
+    /// coordinates (between 0 and 1) and `b` is `true` for points that should be part of the mask
+    /// and `false` for points that should be part of the background and so excluded from the mask.
     pub fn forward_for_embeddings(
         &self,
         img_embeddings: &Tensor,
         original_h: usize,
         original_w: usize,
-        points: &[(f64, f64)],
+        points: &[(f64, f64, bool)],
         multimask_output: bool,
     ) -> Result<(Tensor, Tensor)> {
         let image_pe = self.prompt_encoder.get_dense_pe()?;
@@ -165,14 +170,18 @@
             let n_points = points.len();
             let xys = points
                 .iter()
-                .flat_map(|(x, y)| {
+                .flat_map(|(x, y, _b)| {
                     let x = (*x as f32) * (original_w as f32);
                     let y = (*y as f32) * (original_h as f32);
                     [x, y]
                 })
                 .collect::<Vec<_>>();
+            let labels = points
+                .iter()
+                .map(|(_x, _y, b)| if *b { 1f32 } else { 0f32 })
+                .collect::<Vec<_>>();
             let points = Tensor::from_vec(xys, (1, n_points, 2), img_embeddings.device())?;
-            let labels = Tensor::ones((1, n_points), DType::F32, img_embeddings.device())?;
+            let labels = Tensor::from_vec(labels, (1, n_points), img_embeddings.device())?;
             Some((points, labels))
         };
         let points = points.as_ref().map(|(x, y)| (x, y));
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
index a56e48c2..12349493 100644
--- a/candle-wasm-examples/segment-anything/src/bin/m.rs
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -94,7 +94,7 @@ impl Model {
             &embeddings.data,
             embeddings.height as usize,
             embeddings.width as usize,
-            &[(x, y)],
+            &[(x, y, true)],
             false,
         )?;
         let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
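
For illustration only (not part of the patch): a minimal sketch of the new prompt format, assuming `sam` is an already-loaded `Sam` model and `img` a preprocessed image tensor. Each prompt point is `(x, y, b)` with relative coordinates and a boolean selecting foreground (`true`) or background (`false`):

    // Hypothetical usage sketch; `sam: Sam` and `img: Tensor` are assumed to exist.
    let points: Vec<(f64, f64, bool)> = vec![
        (0.60, 0.55, true),  // positive prompt: should end up inside the mask
        (0.30, 0.30, false), // negative prompt: should be treated as background
    ];
    let (mask, iou) = sam.forward(&img, &points, /* multimask_output */ false)?;

The example binary exposes the same mechanism through the repeatable --point and --neg-point flags; positive points are drawn in red and negative points in green on the saved sam_merged.jpg.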