diff --git a/candle-examples/examples/segment-anything/README.md b/candle-examples/examples/segment-anything/README.md index 3c5b034f..1b255bd2 100644 --- a/candle-examples/examples/segment-anything/README.md +++ b/candle-examples/examples/segment-anything/README.md @@ -16,25 +16,30 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM). cargo run --example segment-anything --release -- \ --image candle-examples/examples/yolo-v8/assets/bike.jpg --use-tiny - --point-x 0.4 - --point-y 0.3 + --point-x 0.6,0.6 + --point-y 0.6,0.55 ``` Running this command generates a `sam_merged.jpg` file containing the original -image with a blue overlay of the selected mask. The red dot represents the prompt -specified by `--point-x 0.4 --point-y 0.3`, this prompt is assumed to be part +image with a blue overlay of the selected mask. The red dots represent the prompt +specified by `--point-x 0.6,0.6 --point-y 0.6,0.55`, this prompt is assumed to be part of the target mask. The values used for `--point-x` and `--point-y` should be between 0 and 1 and are proportional to the image dimension, i.e. use 0.5 for the image center. +Original image: ![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg) -![Leading group, Giro d'Italia 2021](./assets/sam_merged.jpg) +Segment results by prompting with a single point `--point-x 0.6 --point-y 0.55`: +![Leading group, Giro d'Italia 2021](./assets/single_pt_prompt.jpg) + +Segment results by prompting with multiple points `--point-x 0.6,0.6 --point-y 0.6,0.55`: +![Leading group, Giro d'Italia 2021](./assets/two_pt_prompt.jpg) ### Command-line flags - `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default one. -- `--point-x`, `--point-y`: specifies the location of the target point. +- `--point-x`, `--point-y`: specifies the location of the target points. - `--threshold`: sets the threshold value to be part of the mask, a negative value results in a larger mask and can be specified via `--threshold=-1.2`. diff --git a/candle-examples/examples/segment-anything/assets/single_pt_prompt.jpg b/candle-examples/examples/segment-anything/assets/single_pt_prompt.jpg new file mode 100644 index 00000000..d4ace73b Binary files /dev/null and b/candle-examples/examples/segment-anything/assets/single_pt_prompt.jpg differ diff --git a/candle-examples/examples/segment-anything/assets/two_pt_prompt.jpg b/candle-examples/examples/segment-anything/assets/two_pt_prompt.jpg new file mode 100644 index 00000000..0d1b2af4 Binary files /dev/null and b/candle-examples/examples/segment-anything/assets/two_pt_prompt.jpg differ diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 71abe116..1eaf4b7c 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -27,13 +27,13 @@ struct Args { #[arg(long)] generate_masks: bool, - /// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image). - #[arg(long, default_value_t = 0.5)] - point_x: f64, + /// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image). + #[arg(long, use_value_delimiter = true)] + point_x: Vec, - /// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image). - #[arg(long, default_value_t = 0.5)] - point_y: f64, + /// Comma separated list of y coordinate, between 0 and 1 (0.5 is at the middle of the image). + #[arg(long, use_value_delimiter = true)] + point_y: Vec, /// The detection threshold for the mask, 0 is the default value, negative values mean a larger /// mask, positive makes the mask more selective. @@ -111,9 +111,16 @@ pub fn main() -> anyhow::Result<()> { )?; } } else { - let point = Some((args.point_x, args.point_y)); + if args.point_x.len() != args.point_y.len() { + anyhow::bail!( + "number of x coordinates unequal to the number of y coordinates: {} v.s. {}", + args.point_x.len(), + args.point_y.len() + ); + } + let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect(); let start_time = std::time::Instant::now(); - let (mask, iou_predictions) = sam.forward(&image, point, false)?; + let (mask, iou_predictions) = sam.forward(&image, &points, false)?; println!( "mask generated in {:.2}s", start_time.elapsed().as_secs_f32() @@ -151,12 +158,17 @@ pub fn main() -> anyhow::Result<()> { } } } - let (x, y) = ( - (args.point_x * img.width() as f64) as i32, - (args.point_y * img.height() as f64) as i32, - ); - imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200])) - .save("sam_merged.jpg")? + for (x, y) in points { + let x = (x * img.width() as f64) as i32; + let y = (y * img.height() as f64) as i32; + imageproc::drawing::draw_filled_circle_mut( + &mut img, + (x, y), + 3, + image::Rgba([255, 0, 0, 200]), + ); + } + img.save("sam_merged.jpg")? } Ok(()) } diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs index 07e9a759..0c456b79 100644 --- a/candle-transformers/src/models/segment_anything/sam.rs +++ b/candle-transformers/src/models/segment_anything/sam.rs @@ -130,7 +130,7 @@ impl Sam { pub fn forward( &self, img: &Tensor, - point: Option<(f64, f64)>, + points: &[(f64, f64)], multimask_output: bool, ) -> Result<(Tensor, Tensor)> { let (_c, original_h, original_w) = img.dims3()?; @@ -140,7 +140,7 @@ impl Sam { &img_embeddings, original_h, original_w, - point, + points, multimask_output, )?; let mask = low_res_mask @@ -155,20 +155,24 @@ impl Sam { img_embeddings: &Tensor, original_h: usize, original_w: usize, - point: Option<(f64, f64)>, + points: &[(f64, f64)], multimask_output: bool, ) -> Result<(Tensor, Tensor)> { let image_pe = self.prompt_encoder.get_dense_pe()?; - let points = match point { - None => None, - Some((x, y)) => { - let points = Tensor::new( - &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]], - img_embeddings.device(), - )?; - let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?; - Some((points, labels)) - } + let points = if points.is_empty() { + None + } else { + let n_points = points.len(); + let mut coords = vec![]; + points.iter().for_each(|(x, y)| { + let x = (*x as f32) * (original_w as f32); + let y = (*y as f32) * (original_h as f32); + coords.push(x); + coords.push(y); + }); + let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?; + let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?; + Some((points, labels)) }; let points = points.as_ref().map(|(x, y)| (x, y)); let (sparse_prompt_embeddings, dense_prompt_embeddings) = diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs index acd903b0..a56e48c2 100644 --- a/candle-wasm-examples/segment-anything/src/bin/m.rs +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -94,7 +94,7 @@ impl Model { &embeddings.data, embeddings.height as usize, embeddings.width as usize, - Some((x, y)), + &[(x, y)], false, )?; let iou = iou_predictions.flatten(0, 1)?.to_vec1::()?[0];