Draw the mask on a merged image. (#775)

* Draw the mask on a merged image.

* Clippy fix.

* Enable the target point by default.

* Add to the readme.
This commit is contained in:
Laurent Mazare
2023-09-08 14:04:34 +01:00
committed by GitHub
parent 98172d46fa
commit e5703d2f56
2 changed files with 53 additions and 9 deletions

View File

@ -104,11 +104,11 @@ struct Args {
#[arg(long)]
generate_masks: bool,
#[arg(long)]
point_x: Option<f64>,
#[arg(long, default_value_t = 0.5)]
point_x: f64,
#[arg(long)]
point_y: Option<f64>,
#[arg(long, default_value_t = 0.5)]
point_y: f64,
}
pub fn main() -> anyhow::Result<()> {
@ -135,7 +135,7 @@ pub fn main() -> anyhow::Result<()> {
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?;
let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@ -163,7 +163,7 @@ pub fn main() -> anyhow::Result<()> {
/* crop_n_points_downscale_factor */ 1,
)?
} else {
let point = args.point_x.zip(args.point_y);
let point = Some((args.point_x, args.point_y));
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
@ -174,9 +174,50 @@ pub fn main() -> anyhow::Result<()> {
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
let image = sam.preprocess(&image)?;
let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
candle_examples::save_image(&image, "sam_input_scaled.png")?;
if !args.image.ends_with(".safetensors") {
let mut img = image::io::Reader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
}
}
match point {
Some((x, y)) => {
let (x, y) = (
(x * img.width() as f64) as i32,
(y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(
&img,
(x, y),
3,
image::Rgba([255, 0, 0, 200]),
)
.save("sam_merged.jpg")?
}
None => img.save("sam_merged.jpg")?,
};
}
}
Ok(())
}