mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Segment Anything readme (#827)
* Add a readme for the segment-anything model. * Add the original image. * Clean-up the segment anything cli example. * Also print the mask id in the outputs.
This commit is contained in:
40
candle-examples/examples/segment-anything/README.md
Normal file
40
candle-examples/examples/segment-anything/README.md
Normal file
@ -0,0 +1,40 @@
|
||||
# candle-segment-anything: Segment-Anything Model
|
||||
|
||||
This example is based on Meta AI [Segment-Anything
|
||||
Model](https://github.com/facebookresearch/segment-anything). This model
|
||||
provides a robust and fast image segmentation pipeline that can be tweaked via
|
||||
some prompting (requesting some points to be in the target mask, requesting some
|
||||
points to be part of the background so _not_ in the target mask, specifying some
|
||||
bounding box).
|
||||
|
||||
The default backbone can be replaced by the smaller and faster TinyViT model
|
||||
based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
|
||||
|
||||
## Running some example.
|
||||
|
||||
```bash
|
||||
cargo run --example segment-anything --release -- \
|
||||
--image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
--use-tiny
|
||||
--point-x 0.4
|
||||
--point-y 0.3
|
||||
```
|
||||
|
||||
Running this command generates a `sam_merged.jpg` file containing the original
|
||||
image with a blue overlay of the selected mask. The red dot represents the prompt
|
||||
specified by `--point-x 0.4 --point-y 0.3`, this prompt is assumed to be part
|
||||
of the target mask.
|
||||
|
||||
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
|
||||
are proportional to the image dimension, i.e. use 0.5 for the image center.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### Command-line flags
|
||||
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
|
||||
one.
|
||||
- `--point-x`, `--point-y`: specifies the location of the target point.
|
||||
- `--threshold`: sets the threshold value to be part of the mask, a negative
|
||||
value results in a larger mask and can be specified via `--threshold=-1.2`.
|
BIN
candle-examples/examples/segment-anything/assets/sam_merged.jpg
Normal file
BIN
candle-examples/examples/segment-anything/assets/sam_merged.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 157 KiB |
@ -27,12 +27,19 @@ struct Args {
|
||||
#[arg(long)]
|
||||
generate_masks: bool,
|
||||
|
||||
/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_x: f64,
|
||||
|
||||
/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_y: f64,
|
||||
|
||||
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
||||
/// mask, positive makes the mask more selective.
|
||||
#[arg(long, default_value_t = 0.)]
|
||||
threshold: f32,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
@ -57,28 +64,9 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
|
||||
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
||||
let image = match tensors.remove("image") {
|
||||
Some(image) => image,
|
||||
None => {
|
||||
if tensors.len() != 1 {
|
||||
anyhow::bail!("multiple tensors in '{}'", args.image)
|
||||
}
|
||||
tensors.into_values().next().unwrap()
|
||||
}
|
||||
};
|
||||
let image = if image.rank() == 4 {
|
||||
image.get(0)?
|
||||
} else {
|
||||
image
|
||||
};
|
||||
let (_c, h, w) = image.dims3()?;
|
||||
(image, h, w)
|
||||
} else {
|
||||
let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
|
||||
(image.to_device(&device)?, h, w)
|
||||
};
|
||||
let (image, initial_h, initial_w) =
|
||||
candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
|
||||
let image = image.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model = match args.model {
|
||||
@ -113,7 +101,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
/* crop_n_points_downscale_factor */ 1,
|
||||
)?;
|
||||
for (idx, bbox) in bboxes.iter().enumerate() {
|
||||
println!("{bbox:?}");
|
||||
println!("{idx} {bbox:?}");
|
||||
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
|
||||
let (h, w) = mask.dims2()?;
|
||||
let mask = mask.broadcast_as((3, h, w))?;
|
||||
@ -135,56 +123,42 @@ pub fn main() -> anyhow::Result<()> {
|
||||
println!("mask:\n{mask}");
|
||||
println!("iou_predictions: {iou_predictions:?}");
|
||||
|
||||
// Save the mask as an image.
|
||||
let mask = (mask.ge(0f32)? * 255.)?;
|
||||
let mask = (mask.ge(args.threshold)? * 255.)?;
|
||||
let (_one, h, w) = mask.dims3()?;
|
||||
let mask = mask.expand((3, h, w))?;
|
||||
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
|
||||
|
||||
if !args.image.ends_with(".safetensors") {
|
||||
let mut img = image::io::Reader::open(&args.image)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
|
||||
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
|
||||
Some(image) => image,
|
||||
None => anyhow::bail!("error saving merged image"),
|
||||
};
|
||||
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
|
||||
img.width(),
|
||||
img.height(),
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
for x in 0..img.width() {
|
||||
for y in 0..img.height() {
|
||||
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
|
||||
if mask_p.0[0] > 100 {
|
||||
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
|
||||
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
|
||||
img_p.0[1] /= 2;
|
||||
img_p.0[0] /= 2;
|
||||
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
|
||||
}
|
||||
let mut img = image::io::Reader::open(&args.image)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
|
||||
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
|
||||
Some(image) => image,
|
||||
None => anyhow::bail!("error saving merged image"),
|
||||
};
|
||||
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
|
||||
img.width(),
|
||||
img.height(),
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
for x in 0..img.width() {
|
||||
for y in 0..img.height() {
|
||||
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
|
||||
if mask_p.0[0] > 100 {
|
||||
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
|
||||
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
|
||||
img_p.0[1] /= 2;
|
||||
img_p.0[0] /= 2;
|
||||
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
|
||||
}
|
||||
}
|
||||
match point {
|
||||
Some((x, y)) => {
|
||||
let (x, y) = (
|
||||
(x * img.width() as f64) as i32,
|
||||
(y * img.height() as f64) as i32,
|
||||
);
|
||||
imageproc::drawing::draw_filled_circle(
|
||||
&img,
|
||||
(x, y),
|
||||
3,
|
||||
image::Rgba([255, 0, 0, 200]),
|
||||
)
|
||||
.save("sam_merged.jpg")?
|
||||
}
|
||||
None => img.save("sam_merged.jpg")?,
|
||||
};
|
||||
}
|
||||
let (x, y) = (
|
||||
(args.point_x * img.width() as f64) as i32,
|
||||
(args.point_y * img.height() as f64) as i32,
|
||||
);
|
||||
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
|
||||
.save("sam_merged.jpg")?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 80 KiB |
Reference in New Issue
Block a user