mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
[segment-anything] Support multi-point as the prompt input (#945)
* [sam] Support multi-point prompts
* [segment-anything] Pass points by reference
* [segment-anything] Update example code and image
* Fix clippy lint.

---------

Co-authored-by: Yun Ding <yunding@nvidia.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@@ -130,7 +130,7 @@ impl Sam {
|
||||
pub fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
point: Option<(f64, f64)>,
|
||||
points: &[(f64, f64)],
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_c, original_h, original_w) = img.dims3()?;
|
||||
@@ -140,7 +140,7 @@ impl Sam {
|
||||
&img_embeddings,
|
||||
original_h,
|
||||
original_w,
|
||||
point,
|
||||
points,
|
||||
multimask_output,
|
||||
)?;
|
||||
let mask = low_res_mask
|
||||
@@ -155,20 +155,24 @@ impl Sam {
|
||||
img_embeddings: &Tensor,
|
||||
original_h: usize,
|
||||
original_w: usize,
|
||||
point: Option<(f64, f64)>,
|
||||
points: &[(f64, f64)],
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let points = match point {
|
||||
None => None,
|
||||
Some((x, y)) => {
|
||||
let points = Tensor::new(
|
||||
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
|
||||
img_embeddings.device(),
|
||||
)?;
|
||||
let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
|
||||
Some((points, labels))
|
||||
}
|
||||
let points = if points.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let n_points = points.len();
|
||||
let mut coords = vec![];
|
||||
points.iter().for_each(|(x, y)| {
|
||||
let x = (*x as f32) * (original_w as f32);
|
||||
let y = (*y as f32) * (original_h as f32);
|
||||
coords.push(x);
|
||||
coords.push(y);
|
||||
});
|
||||
let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?;
|
||||
let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
|
||||
Some((points, labels))
|
||||
};
|
||||
let points = points.as_ref().map(|(x, y)| (x, y));
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
|
Reference in New Issue
Block a user