[segment-anything] Support multiple points as the prompt input (#945)

* [sam] Support multi-point prompts

* [segment-anything] Pass points by reference

* [segment-anything] Update example code and image

* Fix clippy lint.

---------

Co-authored-by: Yun Ding <yunding@nvidia.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
GeauxEric
2023-09-25 04:14:10 -07:00
committed by GitHub
parent dc47224ab9
commit 7f2bbcf746
6 changed files with 55 additions and 34 deletions

View File

@ -130,7 +130,7 @@ impl Sam {
pub fn forward(
&self,
img: &Tensor,
point: Option<(f64, f64)>,
points: &[(f64, f64)],
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let (_c, original_h, original_w) = img.dims3()?;
@ -140,7 +140,7 @@ impl Sam {
&img_embeddings,
original_h,
original_w,
point,
points,
multimask_output,
)?;
let mask = low_res_mask
@ -155,20 +155,24 @@ impl Sam {
img_embeddings: &Tensor,
original_h: usize,
original_w: usize,
point: Option<(f64, f64)>,
points: &[(f64, f64)],
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = match point {
None => None,
Some((x, y)) => {
let points = Tensor::new(
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
img_embeddings.device(),
)?;
let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
Some((points, labels))
}
let points = if points.is_empty() {
None
} else {
let n_points = points.len();
let mut coords = vec![];
points.iter().for_each(|(x, y)| {
let x = (*x as f32) * (original_w as f32);
let y = (*y as f32) * (original_h as f32);
coords.push(x);
coords.push(y);
});
let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?;
let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
Some((points, labels))
};
let points = points.as_ref().map(|(x, y)| (x, y));
let (sparse_prompt_embeddings, dense_prompt_embeddings) =