mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Draw the mask on a merged image. (#775)
* Draw the mask on a merged image. * Clippy fix. * Enable the target point by default. * Add to the readme.
This commit is contained in:
@ -104,11 +104,11 @@ struct Args {
|
||||
#[arg(long)]
|
||||
generate_masks: bool,
|
||||
|
||||
#[arg(long)]
|
||||
point_x: Option<f64>,
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_x: f64,
|
||||
|
||||
#[arg(long)]
|
||||
point_y: Option<f64>,
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_y: f64,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
@ -135,7 +135,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let (_c, h, w) = image.dims3()?;
|
||||
(image, h, w)
|
||||
} else {
|
||||
let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?;
|
||||
let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
|
||||
(image.to_device(&device)?, h, w)
|
||||
};
|
||||
println!("loaded image {image:?}");
|
||||
@ -163,7 +163,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
/* crop_n_points_downscale_factor */ 1,
|
||||
)?
|
||||
} else {
|
||||
let point = args.point_x.zip(args.point_y);
|
||||
let point = Some((args.point_x, args.point_y));
|
||||
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
||||
println!("mask:\n{mask}");
|
||||
println!("iou_predictions: {iou_predictions:?}");
|
||||
@ -174,9 +174,50 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let mask = mask.expand((3, h, w))?;
|
||||
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
|
||||
|
||||
let image = sam.preprocess(&image)?;
|
||||
let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
|
||||
candle_examples::save_image(&image, "sam_input_scaled.png")?;
|
||||
if !args.image.ends_with(".safetensors") {
|
||||
let mut img = image::io::Reader::open(&args.image)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
|
||||
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
|
||||
Some(image) => image,
|
||||
None => anyhow::bail!("error saving merged image"),
|
||||
};
|
||||
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
|
||||
img.width(),
|
||||
img.height(),
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
for x in 0..img.width() {
|
||||
for y in 0..img.height() {
|
||||
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
|
||||
if mask_p.0[0] > 100 {
|
||||
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
|
||||
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
|
||||
img_p.0[1] /= 2;
|
||||
img_p.0[0] /= 2;
|
||||
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
|
||||
}
|
||||
}
|
||||
}
|
||||
match point {
|
||||
Some((x, y)) => {
|
||||
let (x, y) = (
|
||||
(x * img.width() as f64) as i32,
|
||||
(y * img.height() as f64) as i32,
|
||||
);
|
||||
imageproc::drawing::draw_filled_circle(
|
||||
&img,
|
||||
(x, y),
|
||||
3,
|
||||
image::Rgba([255, 0, 0, 200]),
|
||||
)
|
||||
.save("sam_merged.jpg")?
|
||||
}
|
||||
None => img.save("sam_merged.jpg")?,
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user