Fix the multiple points case for sam. (#998)

This commit is contained in:
Laurent Mazare
2023-09-29 22:39:43 +02:00
committed by GitHub
parent 0ac2db577b
commit d188d6a764

View File

@ -171,8 +171,8 @@ impl Sam {
[x, y]
})
.collect::<Vec<_>>();
let points = Tensor::from_vec(xys, (n_points, 1, 2), img_embeddings.device())?;
let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
let points = Tensor::from_vec(xys, (1, n_points, 2), img_embeddings.device())?;
let labels = Tensor::ones((1, n_points), DType::F32, img_embeddings.device())?;
Some((points, labels))
};
let points = points.as_ref().map(|(x, y)| (x, y));