Use a single flag for the point argument. (#958)

This commit is contained in:
Laurent Mazare
2023-09-25 12:53:24 +01:00
committed by GitHub
parent 7f2bbcf746
commit a36d883254
3 changed files with 31 additions and 31 deletions

View File

@ -16,30 +16,29 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
cargo run --example segment-anything --release -- \ cargo run --example segment-anything --release -- \
--image candle-examples/examples/yolo-v8/assets/bike.jpg --image candle-examples/examples/yolo-v8/assets/bike.jpg
--use-tiny --use-tiny
--point-x 0.6,0.6 --point 0.6,0.6 --point 0.6,0.55
--point-y 0.6,0.55
``` ```
Running this command generates a `sam_merged.jpg` file containing the original Running this command generates a `sam_merged.jpg` file containing the original
image with a blue overlay of the selected mask. The red dots represent the prompt image with a blue overlay of the selected mask. The red dots represent the prompt
specified by `--point-x 0.6,0.6 --point-y 0.6,0.55`, this prompt is assumed to be part specified by `--point 0.6,0.6 --point 0.6,0.55`; this prompt is assumed to be part
of the target mask. of the target mask.
The values used for `--point-x` and `--point-y` should be between 0 and 1 and The values used for `--point` should be a comma-delimited pair of float values between 0 and 1.
are proportional to the image dimension, i.e. use 0.5 for the image center. They are proportional to the image dimension, i.e. use 0.5 for the image center.
Original image: Original image:
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg) ![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
Segment results by prompting with a single point `--point-x 0.6 --point-y 0.55`: Segment results by prompting with a single point `--point 0.6,0.55`:
![Leading group, Giro d'Italia 2021](./assets/single_pt_prompt.jpg) ![Leading group, Giro d'Italia 2021](./assets/single_pt_prompt.jpg)
Segment results by prompting with multiple points `--point-x 0.6,0.6 --point-y 0.6,0.55`: Segment results by prompting with multiple points `--point 0.6,0.6 --point 0.6,0.55`:
![Leading group, Giro d'Italia 2021](./assets/two_pt_prompt.jpg) ![Leading group, Giro d'Italia 2021](./assets/two_pt_prompt.jpg)
### Command-line flags ### Command-line flags
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default - `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
one. one.
- `--point-x`, `--point-y`: specifies the location of the target points. - `--point`: specifies the location of the target points.
- `--threshold`: sets the threshold value to be part of the mask; a negative - `--threshold`: sets the threshold value to be part of the mask; a negative
value results in a larger mask and can be specified via `--threshold=-1.2`. value results in a larger mask and can be specified via `--threshold=-1.2`.

View File

@ -27,13 +27,9 @@ struct Args {
#[arg(long)] #[arg(long)]
generate_masks: bool, generate_masks: bool,
/// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image). /// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, use_value_delimiter = true)] #[arg(long)]
point_x: Vec<f64>, point: Vec<String>,
/// Comma separated list of y coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, use_value_delimiter = true)]
point_y: Vec<f64>,
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger /// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective. /// mask, positive makes the mask more selective.
@ -111,14 +107,18 @@ pub fn main() -> anyhow::Result<()> {
)?; )?;
} }
} else { } else {
if args.point_x.len() != args.point_y.len() { let points = args
anyhow::bail!( .point
"number of x coordinates unequal to the number of y coordinates: {} v.s. {}", .iter()
args.point_x.len(), .map(|point| {
args.point_y.len() use std::str::FromStr;
); let xy = point.split(',').collect::<Vec<_>>();
if xy.len() != 2 {
anyhow::bail!("expected format for points is 0.4,0.2")
} }
let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect(); Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?))
})
.collect::<anyhow::Result<Vec<_>>>()?;
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, &points, false)?; let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
println!( println!(

View File

@ -163,14 +163,15 @@ impl Sam {
None None
} else { } else {
let n_points = points.len(); let n_points = points.len();
let mut coords = vec![]; let xys = points
points.iter().for_each(|(x, y)| { .iter()
.flat_map(|(x, y)| {
let x = (*x as f32) * (original_w as f32); let x = (*x as f32) * (original_w as f32);
let y = (*y as f32) * (original_h as f32); let y = (*y as f32) * (original_h as f32);
coords.push(x); [x, y]
coords.push(y); })
}); .collect::<Vec<_>>();
let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?; let points = Tensor::from_vec(xys, (n_points, 1, 2), img_embeddings.device())?;
let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?; let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
Some((points, labels)) Some((points, labels))
}; };