Use a single flag for the point argument. (#958)

This commit is contained in:
Laurent Mazare
2023-09-25 12:53:24 +01:00
committed by GitHub
parent 7f2bbcf746
commit a36d883254
3 changed files with 31 additions and 31 deletions

View File

@ -163,14 +163,15 @@ impl Sam {
None
} else {
let n_points = points.len();
let mut coords = vec![];
points.iter().for_each(|(x, y)| {
let x = (*x as f32) * (original_w as f32);
let y = (*y as f32) * (original_h as f32);
coords.push(x);
coords.push(y);
});
let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?;
let xys = points
.iter()
.flat_map(|(x, y)| {
let x = (*x as f32) * (original_w as f32);
let y = (*y as f32) * (original_h as f32);
[x, y]
})
.collect::<Vec<_>>();
let points = Tensor::from_vec(xys, (n_points, 1, 2), img_embeddings.device())?;
let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
Some((points, labels))
};