Segment Anything readme (#827)

* Add a readme for the segment-anything model.

* Add the original image.

* Clean-up the segment anything cli example.

* Also print the mask id in the outputs.
This commit is contained in:
Laurent Mazare
2023-09-12 14:35:55 +01:00
committed by GitHub
parent 25aacda28e
commit 42da17694a
4 changed files with 81 additions and 67 deletions

View File

@ -0,0 +1,40 @@
# candle-segment-anything: Segment-Anything Model
This example is based on Meta AI's [Segment-Anything
Model](https://github.com/facebookresearch/segment-anything). This model
provides a robust and fast image segmentation pipeline that can be tweaked via
some prompting (requesting some points to be in the target mask, requesting some
points to be part of the background so _not_ in the target mask, specifying some
bounding box).
The default backbone can be replaced by the smaller and faster TinyViT model
based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
## Running an example
```bash
cargo run --example segment-anything --release -- \
--image candle-examples/examples/yolo-v8/assets/bike.jpg \
--use-tiny \
--point-x 0.4 \
--point-y 0.3
```
Running this command generates a `sam_merged.jpg` file containing the original
image with a blue overlay of the selected mask. The red dot represents the prompt
specified by `--point-x 0.4 --point-y 0.3`; this prompt is assumed to be part
of the target mask.
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
are relative to the image dimensions, e.g. use 0.5 for the image center.
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
![Leading group, Giro d'Italia 2021](./assets/sam_merged.jpg)
### Command-line flags
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
one.
- `--point-x`, `--point-y`: specify the location of the target point.
- `--threshold`: sets the value a mask logit must reach for a pixel to be part of
the mask (default 0). A negative value results in a larger mask and can be
specified via `--threshold=-1.2`.

Binary file not shown.

After

Width:  |  Height:  |  Size: 157 KiB

View File

@ -27,12 +27,19 @@ struct Args {
#[arg(long)]
generate_masks: bool,
/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_x: f64,
/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_y: f64,
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
#[arg(long, default_value_t = 0.)]
threshold: f32,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -57,28 +64,9 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
let image = match tensors.remove("image") {
Some(image) => image,
None => {
if tensors.len() != 1 {
anyhow::bail!("multiple tensors in '{}'", args.image)
}
tensors.into_values().next().unwrap()
}
};
let image = if image.rank() == 4 {
image.get(0)?
} else {
image
};
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
let (image, initial_h, initial_w) =
candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
let image = image.to_device(&device)?;
println!("loaded image {image:?}");
let model = match args.model {
@ -113,7 +101,7 @@ pub fn main() -> anyhow::Result<()> {
/* crop_n_points_downscale_factor */ 1,
)?;
for (idx, bbox) in bboxes.iter().enumerate() {
println!("{bbox:?}");
println!("{idx} {bbox:?}");
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
let (h, w) = mask.dims2()?;
let mask = mask.broadcast_as((3, h, w))?;
@ -135,56 +123,42 @@ pub fn main() -> anyhow::Result<()> {
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
// Save the mask as an image.
let mask = (mask.ge(0f32)? * 255.)?;
let mask = (mask.ge(args.threshold)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
if !args.image.ends_with(".safetensors") {
let mut img = image::io::Reader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
let mut img = image::io::Reader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
}
match point {
Some((x, y)) => {
let (x, y) = (
(x * img.width() as f64) as i32,
(y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(
&img,
(x, y),
3,
image::Rgba([255, 0, 0, 200]),
)
.save("sam_merged.jpg")?
}
None => img.save("sam_merged.jpg")?,
};
}
let (x, y) = (
(args.point_x * img.width() as f64) as i32,
(args.point_y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
.save("sam_merged.jpg")?
}
Ok(())
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 80 KiB