Use a single flag for the point argument. (#958)

This commit is contained in:
Laurent Mazare
2023-09-25 12:53:24 +01:00
committed by GitHub
parent 7f2bbcf746
commit a36d883254
3 changed files with 31 additions and 31 deletions

View File

@ -27,13 +27,9 @@ struct Args {
#[arg(long)]
generate_masks: bool,
/// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, use_value_delimiter = true)]
point_x: Vec<f64>,
/// Comma separated list of y coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, use_value_delimiter = true)]
point_y: Vec<f64>,
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long)]
point: Vec<String>,
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
@ -111,14 +107,18 @@ pub fn main() -> anyhow::Result<()> {
)?;
}
} else {
if args.point_x.len() != args.point_y.len() {
anyhow::bail!(
"number of x coordinates unequal to the number of y coordinates: {} v.s. {}",
args.point_x.len(),
args.point_y.len()
);
}
let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect();
let points = args
.point
.iter()
.map(|point| {
use std::str::FromStr;
let xy = point.split(',').collect::<Vec<_>>();
if xy.len() != 2 {
anyhow::bail!("expected format for points is 0.4,0.2")
}
Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?))
})
.collect::<anyhow::Result<Vec<_>>>()?;
let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
println!(