Get the comparison operation to work on scalar values. (#780)

* Get the comparison operation to work on scalar values.

* Add some time measurement.
This commit is contained in:
Laurent Mazare
2023-09-08 20:13:29 +01:00
committed by GitHub
parent 0906acab91
commit acf8f10ae1
5 changed files with 49 additions and 12 deletions

View File

@ -209,12 +209,17 @@ pub fn main() -> anyhow::Result<()> {
}
} else {
let point = Some((args.point_x, args.point_y));
let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
println!(
"mask generated in {:.2}s",
start_time.elapsed().as_secs_f32()
);
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
// Save the mask as an image.
let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
let mask = (mask.ge(0f32)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;