Use a single flag for the point argument. (#958)

This commit is contained in:
Laurent Mazare
2023-09-25 12:53:24 +01:00
committed by GitHub
parent 7f2bbcf746
commit a36d883254
3 changed files with 31 additions and 31 deletions

View File

@ -16,30 +16,29 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
cargo run --example segment-anything --release -- \
--image candle-examples/examples/yolo-v8/assets/bike.jpg \
--use-tiny \
--point-x 0.6,0.6
--point-y 0.6,0.55
--point 0.6,0.6 --point 0.6,0.55
```
Running this command generates a `sam_merged.jpg` file containing the original
image with a blue overlay of the selected mask. The red dots represent the prompt
specified by `--point-x 0.6,0.6 --point-y 0.6,0.55`, this prompt is assumed to be part
specified by `--point 0.6,0.6 --point 0.6,0.55`; this prompt is assumed to be part
of the target mask.
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
are proportional to the image dimension, i.e. use 0.5 for the image center.
The values used for `--point` should be a comma-delimited pair of float values between 0 and 1.
They are proportional to the image dimensions, i.e. use 0.5 for the image center.
Original image:
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
Segment results by prompting with a single point `--point-x 0.6 --point-y 0.55`:
Segment results by prompting with a single point `--point 0.6,0.55`:
![Leading group, Giro d'Italia 2021](./assets/single_pt_prompt.jpg)
Segment results by prompting with multiple points `--point-x 0.6,0.6 --point-y 0.6,0.55`:
Segment results by prompting with multiple points `--point 0.6,0.6 --point 0.6,0.55`:
![Leading group, Giro d'Italia 2021](./assets/two_pt_prompt.jpg)
### Command-line flags
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
one.
- `--point-x`, `--point-y`: specifies the location of the target points.
- `--point`: specifies the location of a target point as an `x,y` pair; repeat the flag to specify multiple points.
- `--threshold`: sets the threshold value to be part of the mask, a negative
value results in a larger mask and can be specified via `--threshold=-1.2`.

View File

@ -27,13 +27,9 @@ struct Args {
#[arg(long)]
generate_masks: bool,
/// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, use_value_delimiter = true)]
point_x: Vec<f64>,
/// Comma separated list of y coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, use_value_delimiter = true)]
point_y: Vec<f64>,
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long)]
point: Vec<String>,
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
@ -111,14 +107,18 @@ pub fn main() -> anyhow::Result<()> {
)?;
}
} else {
if args.point_x.len() != args.point_y.len() {
anyhow::bail!(
"number of x coordinates unequal to the number of y coordinates: {} v.s. {}",
args.point_x.len(),
args.point_y.len()
);
let points = args
.point
.iter()
.map(|point| {
use std::str::FromStr;
let xy = point.split(',').collect::<Vec<_>>();
if xy.len() != 2 {
anyhow::bail!("expected format for points is 0.4,0.2")
}
let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect();
Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?))
})
.collect::<anyhow::Result<Vec<_>>>()?;
let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
println!(

View File

@ -163,14 +163,15 @@ impl Sam {
None
} else {
let n_points = points.len();
let mut coords = vec![];
points.iter().for_each(|(x, y)| {
let xys = points
.iter()
.flat_map(|(x, y)| {
let x = (*x as f32) * (original_w as f32);
let y = (*y as f32) * (original_h as f32);
coords.push(x);
coords.push(y);
});
let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?;
[x, y]
})
.collect::<Vec<_>>();
let points = Tensor::from_vec(xys, (n_points, 1, 2), img_embeddings.device())?;
let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
Some((points, labels))
};