mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
[segment-anything] Support multi-point as the prompt input (#945)
* [sam] Support multi-point prompts * [segment-anything] Pass points by reference * [segment-anything] Update example code and image * Fix clippy lint. --------- Co-authored-by: Yun Ding <yunding@nvidia.com> Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -16,25 +16,30 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
|
|||||||
cargo run --example segment-anything --release -- \
|
cargo run --example segment-anything --release -- \
|
||||||
--image candle-examples/examples/yolo-v8/assets/bike.jpg
|
--image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
--use-tiny
|
--use-tiny
|
||||||
--point-x 0.4
|
--point-x 0.6,0.6
|
||||||
--point-y 0.3
|
--point-y 0.6,0.55
|
||||||
```
|
```
|
||||||
|
|
||||||
Running this command generates a `sam_merged.jpg` file containing the original
|
Running this command generates a `sam_merged.jpg` file containing the original
|
||||||
image with a blue overlay of the selected mask. The red dot represents the prompt
|
image with a blue overlay of the selected mask. The red dots represent the prompt
|
||||||
specified by `--point-x 0.4 --point-y 0.3`, this prompt is assumed to be part
|
specified by `--point-x 0.6,0.6 --point-y 0.6,0.55`, this prompt is assumed to be part
|
||||||
of the target mask.
|
of the target mask.
|
||||||
|
|
||||||
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
|
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
|
||||||
are proportional to the image dimension, i.e. use 0.5 for the image center.
|
are proportional to the image dimension, i.e. use 0.5 for the image center.
|
||||||
|
|
||||||
|
Original image:
|
||||||

|

|
||||||
|
|
||||||

|
Segment results by prompting with a single point `--point-x 0.6 --point-y 0.55`:
|
||||||
|

|
||||||
|
|
||||||
|
Segment results by prompting with multiple points `--point-x 0.6,0.6 --point-y 0.6,0.55`:
|
||||||
|

|
||||||
|
|
||||||
### Command-line flags
|
### Command-line flags
|
||||||
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
|
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
|
||||||
one.
|
one.
|
||||||
- `--point-x`, `--point-y`: specifies the location of the target point.
|
- `--point-x`, `--point-y`: specifies the location of the target points.
|
||||||
- `--threshold`: sets the threshold value to be part of the mask, a negative
|
- `--threshold`: sets the threshold value to be part of the mask, a negative
|
||||||
value results in a larger mask and can be specified via `--threshold=-1.2`.
|
value results in a larger mask and can be specified via `--threshold=-1.2`.
|
||||||
|
Binary file not shown.
After Width: | Height: | Size: 158 KiB |
Binary file not shown.
After Width: | Height: | Size: 158 KiB |
@ -27,13 +27,13 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
generate_masks: bool,
|
generate_masks: bool,
|
||||||
|
|
||||||
/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
/// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image).
|
||||||
#[arg(long, default_value_t = 0.5)]
|
#[arg(long, use_value_delimiter = true)]
|
||||||
point_x: f64,
|
point_x: Vec<f64>,
|
||||||
|
|
||||||
/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
/// Comma separated list of y coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||||
#[arg(long, default_value_t = 0.5)]
|
#[arg(long, use_value_delimiter = true)]
|
||||||
point_y: f64,
|
point_y: Vec<f64>,
|
||||||
|
|
||||||
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
||||||
/// mask, positive makes the mask more selective.
|
/// mask, positive makes the mask more selective.
|
||||||
@ -111,9 +111,16 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let point = Some((args.point_x, args.point_y));
|
if args.point_x.len() != args.point_y.len() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"number of x coordinates unequal to the number of y coordinates: {} v.s. {}",
|
||||||
|
args.point_x.len(),
|
||||||
|
args.point_y.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect();
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
|
||||||
println!(
|
println!(
|
||||||
"mask generated in {:.2}s",
|
"mask generated in {:.2}s",
|
||||||
start_time.elapsed().as_secs_f32()
|
start_time.elapsed().as_secs_f32()
|
||||||
@ -151,12 +158,17 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let (x, y) = (
|
for (x, y) in points {
|
||||||
(args.point_x * img.width() as f64) as i32,
|
let x = (x * img.width() as f64) as i32;
|
||||||
(args.point_y * img.height() as f64) as i32,
|
let y = (y * img.height() as f64) as i32;
|
||||||
|
imageproc::drawing::draw_filled_circle_mut(
|
||||||
|
&mut img,
|
||||||
|
(x, y),
|
||||||
|
3,
|
||||||
|
image::Rgba([255, 0, 0, 200]),
|
||||||
);
|
);
|
||||||
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
|
}
|
||||||
.save("sam_merged.jpg")?
|
img.save("sam_merged.jpg")?
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -130,7 +130,7 @@ impl Sam {
|
|||||||
pub fn forward(
|
pub fn forward(
|
||||||
&self,
|
&self,
|
||||||
img: &Tensor,
|
img: &Tensor,
|
||||||
point: Option<(f64, f64)>,
|
points: &[(f64, f64)],
|
||||||
multimask_output: bool,
|
multimask_output: bool,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
) -> Result<(Tensor, Tensor)> {
|
||||||
let (_c, original_h, original_w) = img.dims3()?;
|
let (_c, original_h, original_w) = img.dims3()?;
|
||||||
@ -140,7 +140,7 @@ impl Sam {
|
|||||||
&img_embeddings,
|
&img_embeddings,
|
||||||
original_h,
|
original_h,
|
||||||
original_w,
|
original_w,
|
||||||
point,
|
points,
|
||||||
multimask_output,
|
multimask_output,
|
||||||
)?;
|
)?;
|
||||||
let mask = low_res_mask
|
let mask = low_res_mask
|
||||||
@ -155,20 +155,24 @@ impl Sam {
|
|||||||
img_embeddings: &Tensor,
|
img_embeddings: &Tensor,
|
||||||
original_h: usize,
|
original_h: usize,
|
||||||
original_w: usize,
|
original_w: usize,
|
||||||
point: Option<(f64, f64)>,
|
points: &[(f64, f64)],
|
||||||
multimask_output: bool,
|
multimask_output: bool,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
) -> Result<(Tensor, Tensor)> {
|
||||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||||
let points = match point {
|
let points = if points.is_empty() {
|
||||||
None => None,
|
None
|
||||||
Some((x, y)) => {
|
} else {
|
||||||
let points = Tensor::new(
|
let n_points = points.len();
|
||||||
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
|
let mut coords = vec![];
|
||||||
img_embeddings.device(),
|
points.iter().for_each(|(x, y)| {
|
||||||
)?;
|
let x = (*x as f32) * (original_w as f32);
|
||||||
let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
|
let y = (*y as f32) * (original_h as f32);
|
||||||
|
coords.push(x);
|
||||||
|
coords.push(y);
|
||||||
|
});
|
||||||
|
let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?;
|
||||||
|
let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
|
||||||
Some((points, labels))
|
Some((points, labels))
|
||||||
}
|
|
||||||
};
|
};
|
||||||
let points = points.as_ref().map(|(x, y)| (x, y));
|
let points = points.as_ref().map(|(x, y)| (x, y));
|
||||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||||
|
@ -94,7 +94,7 @@ impl Model {
|
|||||||
&embeddings.data,
|
&embeddings.data,
|
||||||
embeddings.height as usize,
|
embeddings.height as usize,
|
||||||
embeddings.width as usize,
|
embeddings.width as usize,
|
||||||
Some((x, y)),
|
&[(x, y)],
|
||||||
false,
|
false,
|
||||||
)?;
|
)?;
|
||||||
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
|
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
|
||||||
|
Reference in New Issue
Block a user