mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Use a single flag for the point argument. (#958)
This commit is contained in:
@ -163,14 +163,15 @@ impl Sam {
|
||||
None
|
||||
} else {
|
||||
let n_points = points.len();
|
||||
let mut coords = vec![];
|
||||
points.iter().for_each(|(x, y)| {
|
||||
let x = (*x as f32) * (original_w as f32);
|
||||
let y = (*y as f32) * (original_h as f32);
|
||||
coords.push(x);
|
||||
coords.push(y);
|
||||
});
|
||||
let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?;
|
||||
let xys = points
|
||||
.iter()
|
||||
.flat_map(|(x, y)| {
|
||||
let x = (*x as f32) * (original_w as f32);
|
||||
let y = (*y as f32) * (original_h as f32);
|
||||
[x, y]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let points = Tensor::from_vec(xys, (n_points, 1, 2), img_embeddings.device())?;
|
||||
let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
|
||||
Some((points, labels))
|
||||
};
|
||||
|
Reference in New Issue
Block a user