mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Segment Anything readme (#827)
* Add a readme for the segment-anything model. * Add the original image. * Clean-up the segment anything cli example. * Also print the mask id in the outputs.
This commit is contained in:
40
candle-examples/examples/segment-anything/README.md
Normal file
40
candle-examples/examples/segment-anything/README.md
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# candle-segment-anything: Segment-Anything Model
|
||||||
|
|
||||||
|
This example is based on Meta AI [Segment-Anything
|
||||||
|
Model](https://github.com/facebookresearch/segment-anything). This model
|
||||||
|
provides a robust and fast image segmentation pipeline that can be tweaked via
|
||||||
|
some prompting (requesting some points to be in the target mask, requesting some
|
||||||
|
points to be part of the background so _not_ in the target mask, specifying some
|
||||||
|
bounding box).
|
||||||
|
|
||||||
|
The default backbone can be replaced by the smaller and faster TinyViT model
|
||||||
|
based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
|
||||||
|
|
||||||
|
## Running some example.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example segment-anything --release -- \
|
||||||
|
--image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
--use-tiny
|
||||||
|
--point-x 0.4
|
||||||
|
--point-y 0.3
|
||||||
|
```
|
||||||
|
|
||||||
|
Running this command generates a `sam_merged.jpg` file containing the original
|
||||||
|
image with a blue overlay of the selected mask. The red dot represents the prompt
|
||||||
|
specified by `--point-x 0.4 --point-y 0.3`, this prompt is assumed to be part
|
||||||
|
of the target mask.
|
||||||
|
|
||||||
|
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
|
||||||
|
are proportional to the image dimension, i.e. use 0.5 for the image center.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### Command-line flags
|
||||||
|
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
|
||||||
|
one.
|
||||||
|
- `--point-x`, `--point-y`: specifies the location of the target point.
|
||||||
|
- `--threshold`: sets the threshold value to be part of the mask, a negative
|
||||||
|
value results in a larger mask and can be specified via `--threshold=-1.2`.
|
BIN
candle-examples/examples/segment-anything/assets/sam_merged.jpg
Normal file
BIN
candle-examples/examples/segment-anything/assets/sam_merged.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 157 KiB |
@ -27,12 +27,19 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
generate_masks: bool,
|
generate_masks: bool,
|
||||||
|
|
||||||
|
/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||||
#[arg(long, default_value_t = 0.5)]
|
#[arg(long, default_value_t = 0.5)]
|
||||||
point_x: f64,
|
point_x: f64,
|
||||||
|
|
||||||
|
/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||||
#[arg(long, default_value_t = 0.5)]
|
#[arg(long, default_value_t = 0.5)]
|
||||||
point_y: f64,
|
point_y: f64,
|
||||||
|
|
||||||
|
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
||||||
|
/// mask, positive makes the mask more selective.
|
||||||
|
#[arg(long, default_value_t = 0.)]
|
||||||
|
threshold: f32,
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tracing: bool,
|
tracing: bool,
|
||||||
@ -57,28 +64,9 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
|
let (image, initial_h, initial_w) =
|
||||||
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
|
||||||
let image = match tensors.remove("image") {
|
let image = image.to_device(&device)?;
|
||||||
Some(image) => image,
|
|
||||||
None => {
|
|
||||||
if tensors.len() != 1 {
|
|
||||||
anyhow::bail!("multiple tensors in '{}'", args.image)
|
|
||||||
}
|
|
||||||
tensors.into_values().next().unwrap()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let image = if image.rank() == 4 {
|
|
||||||
image.get(0)?
|
|
||||||
} else {
|
|
||||||
image
|
|
||||||
};
|
|
||||||
let (_c, h, w) = image.dims3()?;
|
|
||||||
(image, h, w)
|
|
||||||
} else {
|
|
||||||
let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
|
|
||||||
(image.to_device(&device)?, h, w)
|
|
||||||
};
|
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
@ -113,7 +101,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
/* crop_n_points_downscale_factor */ 1,
|
/* crop_n_points_downscale_factor */ 1,
|
||||||
)?;
|
)?;
|
||||||
for (idx, bbox) in bboxes.iter().enumerate() {
|
for (idx, bbox) in bboxes.iter().enumerate() {
|
||||||
println!("{bbox:?}");
|
println!("{idx} {bbox:?}");
|
||||||
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
|
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
|
||||||
let (h, w) = mask.dims2()?;
|
let (h, w) = mask.dims2()?;
|
||||||
let mask = mask.broadcast_as((3, h, w))?;
|
let mask = mask.broadcast_as((3, h, w))?;
|
||||||
@ -135,56 +123,42 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
println!("mask:\n{mask}");
|
println!("mask:\n{mask}");
|
||||||
println!("iou_predictions: {iou_predictions:?}");
|
println!("iou_predictions: {iou_predictions:?}");
|
||||||
|
|
||||||
// Save the mask as an image.
|
let mask = (mask.ge(args.threshold)? * 255.)?;
|
||||||
let mask = (mask.ge(0f32)? * 255.)?;
|
|
||||||
let (_one, h, w) = mask.dims3()?;
|
let (_one, h, w) = mask.dims3()?;
|
||||||
let mask = mask.expand((3, h, w))?;
|
let mask = mask.expand((3, h, w))?;
|
||||||
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
|
|
||||||
|
|
||||||
if !args.image.ends_with(".safetensors") {
|
let mut img = image::io::Reader::open(&args.image)?
|
||||||
let mut img = image::io::Reader::open(&args.image)?
|
.decode()
|
||||||
.decode()
|
.map_err(candle::Error::wrap)?;
|
||||||
.map_err(candle::Error::wrap)?;
|
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
|
||||||
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
|
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||||
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
|
||||||
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
|
Some(image) => image,
|
||||||
Some(image) => image,
|
None => anyhow::bail!("error saving merged image"),
|
||||||
None => anyhow::bail!("error saving merged image"),
|
};
|
||||||
};
|
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
|
||||||
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
|
img.width(),
|
||||||
img.width(),
|
img.height(),
|
||||||
img.height(),
|
image::imageops::FilterType::CatmullRom,
|
||||||
image::imageops::FilterType::CatmullRom,
|
);
|
||||||
);
|
for x in 0..img.width() {
|
||||||
for x in 0..img.width() {
|
for y in 0..img.height() {
|
||||||
for y in 0..img.height() {
|
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
|
||||||
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
|
if mask_p.0[0] > 100 {
|
||||||
if mask_p.0[0] > 100 {
|
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
|
||||||
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
|
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
|
||||||
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
|
img_p.0[1] /= 2;
|
||||||
img_p.0[1] /= 2;
|
img_p.0[0] /= 2;
|
||||||
img_p.0[0] /= 2;
|
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
|
||||||
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
match point {
|
|
||||||
Some((x, y)) => {
|
|
||||||
let (x, y) = (
|
|
||||||
(x * img.width() as f64) as i32,
|
|
||||||
(y * img.height() as f64) as i32,
|
|
||||||
);
|
|
||||||
imageproc::drawing::draw_filled_circle(
|
|
||||||
&img,
|
|
||||||
(x, y),
|
|
||||||
3,
|
|
||||||
image::Rgba([255, 0, 0, 200]),
|
|
||||||
)
|
|
||||||
.save("sam_merged.jpg")?
|
|
||||||
}
|
|
||||||
None => img.save("sam_merged.jpg")?,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
let (x, y) = (
|
||||||
|
(args.point_x * img.width() as f64) as i32,
|
||||||
|
(args.point_y * img.height() as f64) as i32,
|
||||||
|
);
|
||||||
|
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
|
||||||
|
.save("sam_merged.jpg")?
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Binary file not shown.
Before Width: | Height: | Size: 80 KiB |
Reference in New Issue
Block a user