diff --git a/candle-examples/examples/segment-anything/README.md b/candle-examples/examples/segment-anything/README.md index 1b255bd2..da27f6ce 100644 --- a/candle-examples/examples/segment-anything/README.md +++ b/candle-examples/examples/segment-anything/README.md @@ -16,30 +16,29 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM). cargo run --example segment-anything --release -- \ --image candle-examples/examples/yolo-v8/assets/bike.jpg --use-tiny - --point-x 0.6,0.6 - --point-y 0.6,0.55 + --point 0.6,0.6 --point 0.6,0.55 ``` Running this command generates a `sam_merged.jpg` file containing the original image with a blue overlay of the selected mask. The red dots represent the prompt -specified by `--point-x 0.6,0.6 --point-y 0.6,0.55`, this prompt is assumed to be part +specified by `--point 0.6,0.6 --point 0.6,0.55`, this prompt is assumed to be part of the target mask. -The values used for `--point-x` and `--point-y` should be between 0 and 1 and -are proportional to the image dimension, i.e. use 0.5 for the image center. +The values used for `--point` should be a comma delimited pair of float values. +They are proportional to the image dimension, i.e. use 0.5 for the image center. Original image: ![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg) -Segment results by prompting with a single point `--point-x 0.6 --point-y 0.55`: +Segment results by prompting with a single point `--point 0.6,0.55`: ![Leading group, Giro d'Italia 2021](./assets/single_pt_prompt.jpg) -Segment results by prompting with multiple points `--point-x 0.6,0.6 --point-y 0.6,0.55`: +Segment results by prompting with multiple points `--point 0.6,0.6 --point 0.6,0.55`: ![Leading group, Giro d'Italia 2021](./assets/two_pt_prompt.jpg) ### Command-line flags - `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default one. -- `--point-x`, `--point-y`: specifies the location of the target points. +- `--point`: specifies the location of the target points. - `--threshold`: sets the threshold value to be part of the mask, a negative value results in a larger mask and can be specified via `--threshold=-1.2`. diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 1eaf4b7c..bc586817 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -27,13 +27,9 @@ struct Args { #[arg(long)] generate_masks: bool, - /// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image). - #[arg(long, use_value_delimiter = true)] - point_x: Vec, - - /// Comma separated list of y coordinate, between 0 and 1 (0.5 is at the middle of the image). - #[arg(long, use_value_delimiter = true)] - point_y: Vec, + /// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). + #[arg(long)] + point: Vec, /// The detection threshold for the mask, 0 is the default value, negative values mean a larger /// mask, positive makes the mask more selective. @@ -111,14 +107,18 @@ pub fn main() -> anyhow::Result<()> { )?; } } else { - if args.point_x.len() != args.point_y.len() { - anyhow::bail!( - "number of x coordinates unequal to the number of y coordinates: {} v.s. {}", - args.point_x.len(), - args.point_y.len() - ); - } - let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect(); + let points = args + .point + .iter() + .map(|point| { + use std::str::FromStr; + let xy = point.split(',').collect::>(); + if xy.len() != 2 { + anyhow::bail!("expected format for points is 0.4,0.2") + } + Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?)) + }) + .collect::>>()?; let start_time = std::time::Instant::now(); let (mask, iou_predictions) = sam.forward(&image, &points, false)?; println!( diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs index 0c456b79..49e95adb 100644 --- a/candle-transformers/src/models/segment_anything/sam.rs +++ b/candle-transformers/src/models/segment_anything/sam.rs @@ -163,14 +163,15 @@ impl Sam { None } else { let n_points = points.len(); - let mut coords = vec![]; - points.iter().for_each(|(x, y)| { - let x = (*x as f32) * (original_w as f32); - let y = (*y as f32) * (original_h as f32); - coords.push(x); - coords.push(y); - }); - let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?; + let xys = points + .iter() + .flat_map(|(x, y)| { + let x = (*x as f32) * (original_w as f32); + let y = (*y as f32) * (original_h as f32); + [x, y] + }) + .collect::>(); + let points = Tensor::from_vec(xys, (n_points, 1, 2), img_embeddings.device())?; let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?; Some((points, labels)) };