Draw the mask on a merged image. (#775)

* Draw the mask on a merged image.

* Clippy fix.

* Enable the target point by default.

* Add to the readme.
This commit is contained in:
Laurent Mazare
2023-09-08 14:04:34 +01:00
committed by GitHub
parent 98172d46fa
commit e5703d2f56
2 changed files with 53 additions and 9 deletions

View File

@ -64,6 +64,8 @@ Check out our [examples](./candle-examples/examples/):
- [yolo-v3](./candle-examples/examples/yolo-v3/) and - [yolo-v3](./candle-examples/examples/yolo-v3/) and
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
estimation models. estimation models.
- [segment-anything](./candle-examples/examples/segment-anything/): image
segmentation model with prompt.
Run them using the following commands: Run them using the following commands:
``` ```
cargo run --example whisper --release cargo run --example whisper --release
@ -76,6 +78,7 @@ cargo run --example dinov2 --release -- --image path/to/myinput.jpg
cargo run --example quantized --release cargo run --example quantized --release
cargo run --example yolo-v3 --release -- myimage.jpg cargo run --example yolo-v3 --release -- myimage.jpg
cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
cargo run --example segment-anything --release -- --image myimage.jpg
``` ```
In order to use **CUDA** add `--features cuda` to the example command line. If In order to use **CUDA** add `--features cuda` to the example command line. If

View File

@ -104,11 +104,11 @@ struct Args {
#[arg(long)] #[arg(long)]
generate_masks: bool, generate_masks: bool,
#[arg(long)] #[arg(long, default_value_t = 0.5)]
point_x: Option<f64>, point_x: f64,
#[arg(long)] #[arg(long, default_value_t = 0.5)]
point_y: Option<f64>, point_y: f64,
} }
pub fn main() -> anyhow::Result<()> { pub fn main() -> anyhow::Result<()> {
@ -135,7 +135,7 @@ pub fn main() -> anyhow::Result<()> {
let (_c, h, w) = image.dims3()?; let (_c, h, w) = image.dims3()?;
(image, h, w) (image, h, w)
} else { } else {
let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?; let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w) (image.to_device(&device)?, h, w)
}; };
println!("loaded image {image:?}"); println!("loaded image {image:?}");
@ -163,7 +163,7 @@ pub fn main() -> anyhow::Result<()> {
/* crop_n_points_downscale_factor */ 1, /* crop_n_points_downscale_factor */ 1,
)? )?
} else { } else {
let point = args.point_x.zip(args.point_y); let point = Some((args.point_x, args.point_y));
let (mask, iou_predictions) = sam.forward(&image, point, false)?; let (mask, iou_predictions) = sam.forward(&image, point, false)?;
println!("mask:\n{mask}"); println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}"); println!("iou_predictions: {iou_predictions:?}");
@ -174,9 +174,50 @@ pub fn main() -> anyhow::Result<()> {
let mask = mask.expand((3, h, w))?; let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?; candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
let image = sam.preprocess(&image)?; if !args.image.ends_with(".safetensors") {
let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?; let mut img = image::io::Reader::open(&args.image)?
candle_examples::save_image(&image, "sam_input_scaled.png")?; .decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
}
}
match point {
Some((x, y)) => {
let (x, y) = (
(x * img.width() as f64) as i32,
(y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(
&img,
(x, y),
3,
image::Rgba([255, 0, 0, 200]),
)
.save("sam_merged.jpg")?
}
None => img.save("sam_merged.jpg")?,
};
}
} }
Ok(()) Ok(())
} }