From e5703d2f56ce24652e7ae85dc74484681e4dbcb9 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 8 Sep 2023 14:04:34 +0100 Subject: [PATCH] Draw the mask on a merged image. (#775) * Draw the mask on a merged image. * Clippy fix. * Enable the target point by default. * Add to the readme. --- README.md | 3 + .../examples/segment-anything/main.rs | 59 ++++++++++++++++--- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 076f2363..9e5f938a 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,8 @@ Check out our [examples](./candle-examples/examples/): - [yolo-v3](./candle-examples/examples/yolo-v3/) and [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose estimation models. + [segment-anything](./candle-examples/examples/segment-anything/): image + segmentation model with prompt. Run them using the following commands: ``` cargo run --example whisper --release @@ -76,6 +78,7 @@ cargo run --example dinov2 --release -- --image path/to/myinput.jpg cargo run --example quantized --release cargo run --example yolo-v3 --release -- myimage.jpg cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose +cargo run --example segment-anything --release -- --image myimage.jpg ``` In order to use **CUDA** add `--features cuda` to the example command line. 
If diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 0f0c0482..c5095c0e 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -104,11 +104,11 @@ struct Args { #[arg(long)] generate_masks: bool, - #[arg(long)] - point_x: Option<f64>, + #[arg(long, default_value_t = 0.5)] + point_x: f64, - #[arg(long)] - point_y: Option<f64>, + #[arg(long, default_value_t = 0.5)] + point_y: f64, } pub fn main() -> anyhow::Result<()> { @@ -135,7 +135,7 @@ pub fn main() -> anyhow::Result<()> { let (_c, h, w) = image.dims3()?; (image, h, w) } else { - let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?; + let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?; (image.to_device(&device)?, h, w) }; println!("loaded image {image:?}"); @@ -163,7 +163,7 @@ pub fn main() -> anyhow::Result<()> { /* crop_n_points_downscale_factor */ 1, )? } else { - let point = args.point_x.zip(args.point_y); + let point = Some((args.point_x, args.point_y)); let (mask, iou_predictions) = sam.forward(&image, point, false)?; println!("mask:\n{mask}"); println!("iou_predictions: {iou_predictions:?}"); @@ -174,9 +174,50 @@ pub fn main() -> anyhow::Result<()> { let mask = mask.expand((3, h, w))?; candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?; - let image = sam.preprocess(&image)?; - let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?; - candle_examples::save_image(&image, "sam_input_scaled.png")?; + if !args.image.ends_with(".safetensors") { + let mut img = image::io::Reader::open(&args.image)? 
+ .decode() + .map_err(candle::Error::wrap)?; + let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?; + let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = + match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) { + Some(image) => image, + None => anyhow::bail!("error saving merged image"), + }; + let mask_img = image::DynamicImage::from(mask_img).resize_to_fill( + img.width(), + img.height(), + image::imageops::FilterType::CatmullRom, + ); + for x in 0..img.width() { + for y in 0..img.height() { + let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y); + if mask_p.0[0] > 100 { + let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y); + img_p.0[2] = 255 - (255 - img_p.0[2]) / 2; + img_p.0[1] /= 2; + img_p.0[0] /= 2; + imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p) + } + } + } + match point { + Some((x, y)) => { + let (x, y) = ( + (x * img.width() as f64) as i32, + (y * img.height() as f64) as i32, + ); + imageproc::drawing::draw_filled_circle( + &img, + (x, y), + 3, + image::Rgba([255, 0, 0, 200]), + ) + .save("sam_merged.jpg")? + } + None => img.save("sam_merged.jpg")?, + }; + } } Ok(()) }