mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
[segment-anything] Support multi-point as the prompt input (#945)
* [sam] Support multi-point prompts * [segment-anything] Pass points by reference * [segment-anything] Update example code and image * Fix clippy lint. --------- Co-authored-by: Yun Ding <yunding@nvidia.com> Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -27,13 +27,13 @@ struct Args {
|
||||
#[arg(long)]
|
||||
generate_masks: bool,
|
||||
|
||||
/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_x: f64,
|
||||
/// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
point_x: Vec<f64>,
|
||||
|
||||
/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_y: f64,
|
||||
/// Comma separated list of y coordinates, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
point_y: Vec<f64>,
|
||||
|
||||
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
||||
/// mask, positive makes the mask more selective.
|
||||
@ -111,9 +111,16 @@ pub fn main() -> anyhow::Result<()> {
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
let point = Some((args.point_x, args.point_y));
|
||||
if args.point_x.len() != args.point_y.len() {
|
||||
anyhow::bail!(
|
||||
"number of x coordinates unequal to the number of y coordinates: {} v.s. {}",
|
||||
args.point_x.len(),
|
||||
args.point_y.len()
|
||||
);
|
||||
}
|
||||
let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect();
|
||||
let start_time = std::time::Instant::now();
|
||||
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
||||
let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
|
||||
println!(
|
||||
"mask generated in {:.2}s",
|
||||
start_time.elapsed().as_secs_f32()
|
||||
@ -151,12 +158,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
let (x, y) = (
|
||||
(args.point_x * img.width() as f64) as i32,
|
||||
(args.point_y * img.height() as f64) as i32,
|
||||
);
|
||||
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
|
||||
.save("sam_merged.jpg")?
|
||||
for (x, y) in points {
|
||||
let x = (x * img.width() as f64) as i32;
|
||||
let y = (y * img.height() as f64) as i32;
|
||||
imageproc::drawing::draw_filled_circle_mut(
|
||||
&mut img,
|
||||
(x, y),
|
||||
3,
|
||||
image::Rgba([255, 0, 0, 200]),
|
||||
);
|
||||
}
|
||||
img.save("sam_merged.jpg")?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user