mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Use a single flag for the point argument. (#958)
This commit is contained in:
@ -27,13 +27,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
generate_masks: bool,
|
||||
|
||||
/// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
point_x: Vec<f64>,
|
||||
|
||||
/// Comma separated list of y coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
point_y: Vec<f64>,
|
||||
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long)]
|
||||
point: Vec<String>,
|
||||
|
||||
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
||||
/// mask, positive makes the mask more selective.
|
||||
@ -111,14 +107,18 @@ pub fn main() -> anyhow::Result<()> {
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
if args.point_x.len() != args.point_y.len() {
|
||||
anyhow::bail!(
|
||||
"number of x coordinates unequal to the number of y coordinates: {} v.s. {}",
|
||||
args.point_x.len(),
|
||||
args.point_y.len()
|
||||
);
|
||||
}
|
||||
let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect();
|
||||
let points = args
|
||||
.point
|
||||
.iter()
|
||||
.map(|point| {
|
||||
use std::str::FromStr;
|
||||
let xy = point.split(',').collect::<Vec<_>>();
|
||||
if xy.len() != 2 {
|
||||
anyhow::bail!("expected format for points is 0.4,0.2")
|
||||
}
|
||||
Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?))
|
||||
})
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
let start_time = std::time::Instant::now();
|
||||
let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
|
||||
println!(
|
||||
|
Reference in New Issue
Block a user