From d188d6a7642c470732f740e93c035fd792c9706c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 29 Sep 2023 22:39:43 +0200 Subject: [PATCH] Fix the multiple points case for sam. (#998) --- candle-transformers/src/models/segment_anything/sam.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs index 49e95adb..6de7beb2 100644 --- a/candle-transformers/src/models/segment_anything/sam.rs +++ b/candle-transformers/src/models/segment_anything/sam.rs @@ -171,8 +171,8 @@ impl Sam { [x, y] }) .collect::>(); - let points = Tensor::from_vec(xys, (n_points, 1, 2), img_embeddings.device())?; - let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?; + let points = Tensor::from_vec(xys, (1, n_points, 2), img_embeddings.device())?; + let labels = Tensor::ones((1, n_points), DType::F32, img_embeddings.device())?; Some((points, labels)) }; let points = points.as_ref().map(|(x, y)| (x, y));