Segment Anything readme (#827)

* Add a readme for the segment-anything model.

* Add the original image.

* Clean-up the segment anything cli example.

* Also print the mask id in the outputs.
This commit is contained in:
Laurent Mazare
2023-09-12 14:35:55 +01:00
committed by GitHub
parent 25aacda28e
commit 42da17694a
4 changed files with 81 additions and 67 deletions

View File

@ -27,12 +27,19 @@ struct Args {
#[arg(long)]
generate_masks: bool,
/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_x: f64,
/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_y: f64,
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
#[arg(long, default_value_t = 0.)]
threshold: f32,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -57,28 +64,9 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
let image = match tensors.remove("image") {
Some(image) => image,
None => {
if tensors.len() != 1 {
anyhow::bail!("multiple tensors in '{}'", args.image)
}
tensors.into_values().next().unwrap()
}
};
let image = if image.rank() == 4 {
image.get(0)?
} else {
image
};
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
let (image, initial_h, initial_w) =
candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
let image = image.to_device(&device)?;
println!("loaded image {image:?}");
let model = match args.model {
@ -113,7 +101,7 @@ pub fn main() -> anyhow::Result<()> {
/* crop_n_points_downscale_factor */ 1,
)?;
for (idx, bbox) in bboxes.iter().enumerate() {
println!("{bbox:?}");
println!("{idx} {bbox:?}");
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
let (h, w) = mask.dims2()?;
let mask = mask.broadcast_as((3, h, w))?;
@ -135,56 +123,42 @@ pub fn main() -> anyhow::Result<()> {
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
// Save the mask as an image.
let mask = (mask.ge(0f32)? * 255.)?;
let mask = (mask.ge(args.threshold)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
if !args.image.ends_with(".safetensors") {
let mut img = image::io::Reader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
let mut img = image::io::Reader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
}
match point {
Some((x, y)) => {
let (x, y) = (
(x * img.width() as f64) as i32,
(y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(
&img,
(x, y),
3,
image::Rgba([255, 0, 0, 200]),
)
.save("sam_merged.jpg")?
}
None => img.save("sam_merged.jpg")?,
};
}
let (x, y) = (
(args.point_x * img.width() as f64) as i32,
(args.point_y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
.save("sam_merged.jpg")?
}
Ok(())
}