[segment-anything] Support multi-point as the prompt input (#945)

* [sam] Support multi-point prompts

* [segment-anything] Pass points by reference

* [segment-anything] Update example code and image

* Fix clippy lint.

---------

Co-authored-by: Yun Ding <yunding@nvidia.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
Author: GeauxEric
Date: 2023-09-25 04:14:10 -07:00
Committed by: GitHub
Parent: dc47224ab9
Commit: 7f2bbcf746

6 changed files with 55 additions and 34 deletions


@@ -16,25 +16,30 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
 cargo run --example segment-anything --release -- \
   --image candle-examples/examples/yolo-v8/assets/bike.jpg
   --use-tiny
-  --point-x 0.4
-  --point-y 0.3
+  --point-x 0.6,0.6
+  --point-y 0.6,0.55
 ```

 Running this command generates a `sam_merged.jpg` file containing the original
-image with a blue overlay of the selected mask. The red dot represents the prompt
-specified by `--point-x 0.4 --point-y 0.3`, this prompt is assumed to be part
+image with a blue overlay of the selected mask. The red dots represent the prompt
+specified by `--point-x 0.6,0.6 --point-y 0.6,0.55`; this prompt is assumed to be part
 of the target mask.
 The values used for `--point-x` and `--point-y` should be between 0 and 1 and
 are proportional to the image dimension, i.e. use 0.5 for the image center.

+Original image:
 ![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
-![Leading group, Giro d'Italia 2021](./assets/sam_merged.jpg)
+
+Segment results by prompting with a single point `--point-x 0.6 --point-y 0.55`:
+![Leading group, Giro d'Italia 2021](./assets/single_pt_prompt.jpg)
+
+Segment results by prompting with multiple points `--point-x 0.6,0.6 --point-y 0.6,0.55`:
+![Leading group, Giro d'Italia 2021](./assets/two_pt_prompt.jpg)

 ### Command-line flags

 - `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
   one.
-- `--point-x`, `--point-y`: specifies the location of the target point.
+- `--point-x`, `--point-y`: specify the location of the target points.
 - `--threshold`: sets the threshold value to be part of the mask, a negative
   value results in a larger mask and can be specified via `--threshold=-1.2`.
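The comma separated values pair up positionally: the first `--point-x` entry goes with the first `--point-y` entry, and each normalized pair is scaled by the image width and height. As a quick illustration, here is a standalone sketch (not code from the example; `prompt_points_to_pixels` is a hypothetical helper):

```rust
/// Hypothetical helper: parse comma separated normalized coordinates such as
/// "0.6,0.6" / "0.6,0.55" and convert each (x, y) pair to pixel coordinates.
fn prompt_points_to_pixels(
    xs: &str,
    ys: &str,
    width: u32,
    height: u32,
) -> Result<Vec<(i32, i32)>, String> {
    let parse = |s: &str| -> Result<Vec<f64>, String> {
        s.split(',')
            .map(|v| v.trim().parse::<f64>().map_err(|e| e.to_string()))
            .collect()
    };
    let (xs, ys) = (parse(xs)?, parse(ys)?);
    if xs.len() != ys.len() {
        return Err(format!("{} x coordinates vs {} y coordinates", xs.len(), ys.len()));
    }
    Ok(xs
        .into_iter()
        .zip(ys)
        .map(|(x, y)| ((x * width as f64) as i32, (y * height as f64) as i32))
        .collect())
}

fn main() {
    // The two prompt points from the README example, on a 640x480 image.
    let pts = prompt_points_to_pixels("0.6,0.6", "0.6,0.55", 640, 480).unwrap();
    println!("{pts:?}"); // [(384, 288), (384, 264)]
}
```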

Binary file not shown. (new file, 158 KiB)

Binary file not shown. (new file, 158 KiB)


@@ -27,13 +27,13 @@ struct Args {
     #[arg(long)]
     generate_masks: bool,

-    /// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
-    #[arg(long, default_value_t = 0.5)]
-    point_x: f64,
+    /// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image).
+    #[arg(long, use_value_delimiter = true)]
+    point_x: Vec<f64>,

-    /// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
-    #[arg(long, default_value_t = 0.5)]
-    point_y: f64,
+    /// Comma separated list of y coordinates, between 0 and 1 (0.5 is at the middle of the image).
+    #[arg(long, use_value_delimiter = true)]
+    point_y: Vec<f64>,

     /// The detection threshold for the mask, 0 is the default value, negative values mean a larger
     /// mask, positive makes the mask more selective.
@@ -111,9 +111,16 @@ pub fn main() -> anyhow::Result<()> {
             )?;
         }
     } else {
-        let point = Some((args.point_x, args.point_y));
+        if args.point_x.len() != args.point_y.len() {
+            anyhow::bail!(
+                "number of x coordinates unequal to the number of y coordinates: {} v.s. {}",
+                args.point_x.len(),
+                args.point_y.len()
+            );
+        }
+        let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect();
         let start_time = std::time::Instant::now();
-        let (mask, iou_predictions) = sam.forward(&image, point, false)?;
+        let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
         println!(
             "mask generated in {:.2}s",
             start_time.elapsed().as_secs_f32()
@@ -151,12 +158,17 @@ pub fn main() -> anyhow::Result<()> {
                 }
             }
         }
-        let (x, y) = (
-            (args.point_x * img.width() as f64) as i32,
-            (args.point_y * img.height() as f64) as i32,
-        );
-        imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
-            .save("sam_merged.jpg")?
+        for (x, y) in points {
+            let x = (x * img.width() as f64) as i32;
+            let y = (y * img.height() as f64) as i32;
+            imageproc::drawing::draw_filled_circle_mut(
+                &mut img,
+                (x, y),
+                3,
+                image::Rgba([255, 0, 0, 200]),
+            );
+        }
+        img.save("sam_merged.jpg")?
     }
     Ok(())
 }
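Stripped of everything unrelated to the point flags, the argument handling introduced above follows this pattern; the snippet below is a minimal sketch assuming clap 4 with the `derive` feature and the `anyhow` crate as dependencies:

```rust
use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Comma separated list of x coordinates, between 0 and 1.
    #[arg(long, use_value_delimiter = true)]
    point_x: Vec<f64>,

    /// Comma separated list of y coordinates, between 0 and 1.
    #[arg(long, use_value_delimiter = true)]
    point_y: Vec<f64>,
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    // Each x coordinate must have a matching y coordinate.
    if args.point_x.len() != args.point_y.len() {
        anyhow::bail!(
            "number of x coordinates unequal to the number of y coordinates: {} vs {}",
            args.point_x.len(),
            args.point_y.len()
        );
    }
    // `--point-x 0.6,0.6 --point-y 0.6,0.55` yields [(0.6, 0.6), (0.6, 0.55)].
    let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect();
    println!("{points:?}");
    Ok(())
}
```

Because the fields are `Vec<f64>`, clap's derive API also accepts the flag repeated (`--point-x 0.6 --point-x 0.6`) in addition to the comma separated form.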


@@ -130,7 +130,7 @@ impl Sam {
     pub fn forward(
         &self,
         img: &Tensor,
-        point: Option<(f64, f64)>,
+        points: &[(f64, f64)],
         multimask_output: bool,
     ) -> Result<(Tensor, Tensor)> {
         let (_c, original_h, original_w) = img.dims3()?;
@@ -140,7 +140,7 @@ impl Sam {
             &img_embeddings,
             original_h,
             original_w,
-            point,
+            points,
             multimask_output,
         )?;
         let mask = low_res_mask
@@ -155,20 +155,24 @@ impl Sam {
         img_embeddings: &Tensor,
         original_h: usize,
         original_w: usize,
-        point: Option<(f64, f64)>,
+        points: &[(f64, f64)],
         multimask_output: bool,
     ) -> Result<(Tensor, Tensor)> {
         let image_pe = self.prompt_encoder.get_dense_pe()?;
-        let points = match point {
-            None => None,
-            Some((x, y)) => {
-                let points = Tensor::new(
-                    &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
-                    img_embeddings.device(),
-                )?;
-                let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
-                Some((points, labels))
-            }
-        };
+        let points = if points.is_empty() {
+            None
+        } else {
+            let n_points = points.len();
+            let mut coords = vec![];
+            points.iter().for_each(|(x, y)| {
+                let x = (*x as f32) * (original_w as f32);
+                let y = (*y as f32) * (original_h as f32);
+                coords.push(x);
+                coords.push(y);
+            });
+            let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?;
+            let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
+            Some((points, labels))
+        };
         let points = points.as_ref().map(|(x, y)| (x, y));
         let (sparse_prompt_embeddings, dense_prompt_embeddings) =
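The new point-packing branch in `forward_for_embeddings` can be exercised in isolation with a sketch like the one below; it assumes only the `candle-core` crate, substitutes the CPU device for `img_embeddings.device()`, and wraps the logic in a hypothetical `pack_points` helper:

```rust
use candle_core::{DType, Device, Result, Tensor};

/// Hypothetical helper mirroring the packing above: scale normalized points to
/// pixel space and lay them out as a (n_points, 1, 2) coordinate tensor plus a
/// (n_points, 1) label tensor (label 1.0 marks a foreground prompt point).
fn pack_points(points: &[(f64, f64)], w: usize, h: usize) -> Result<Option<(Tensor, Tensor)>> {
    if points.is_empty() {
        return Ok(None);
    }
    let dev = Device::Cpu;
    // Flatten to [x0, y0, x1, y1, ...] in pixel coordinates.
    let coords: Vec<f32> = points
        .iter()
        .flat_map(|&(x, y)| [x as f32 * w as f32, y as f32 * h as f32])
        .collect();
    let coords = Tensor::from_vec(coords, (points.len(), 1, 2), &dev)?;
    let labels = Tensor::ones((points.len(), 1), DType::F32, &dev)?;
    Ok(Some((coords, labels)))
}

fn main() -> Result<()> {
    if let Some((coords, labels)) = pack_points(&[(0.6, 0.6), (0.6, 0.55)], 1024, 1024)? {
        // coords: [2, 1, 2], labels: [2, 1]
        println!("coords: {:?}, labels: {:?}", coords.shape(), labels.shape());
    }
    Ok(())
}
```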


@@ -94,7 +94,7 @@ impl Model {
             &embeddings.data,
             embeddings.height as usize,
             embeddings.width as usize,
-            Some((x, y)),
+            &[(x, y)],
             false,
         )?;
         let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];