Improve the safetensors loading in the segment-anything example. (#772)

* Improve the safetensors loading in the segment-anything example.

* Properly handle the labels when embedding the point prompts.
This commit is contained in:
Laurent Mazare
2023-09-08 09:39:10 +01:00
committed by GitHub
parent 989a4807b1
commit c1453f00b1
2 changed files with 28 additions and 2 deletions

View File

@ -110,7 +110,7 @@ pub fn main() -> anyhow::Result<()> {
let image = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
match tensors.remove("image") {
let image = match tensors.remove("image") {
Some(image) => image,
None => {
if tensors.len() != 1 {
@ -118,6 +118,11 @@ pub fn main() -> anyhow::Result<()> {
}
tensors.into_values().next().unwrap()
}
};
if image.rank() == 4 {
image.get(0)?
} else {
image
}
} else {
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?

View File

@ -157,7 +157,28 @@ impl PromptEncoder {
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
// TODO: tweak based on labels.
let zeros = point_embedding.zeros_like()?;
let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
&self
.not_a_point_embed
.embeddings()
.broadcast_as(zeros.shape())?,
&point_embedding,
)?;
let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
&self.point_embeddings[0]
.embeddings()
.broadcast_as(zeros.shape())?,
&zeros,
)?;
let point_embedding = (point_embedding + labels0)?;
let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
&self.point_embeddings[1]
.embeddings()
.broadcast_as(zeros.shape())?,
&zeros,
)?;
let point_embedding = (point_embedding + labels1)?;
Ok(point_embedding)
}