mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Improve the safetensor loading in the segment-anything example. (#772)
* Improve the safetensor loading in the segment-anything example. * Properly handle the labels when embedding the point prompts.
This commit is contained in:
@ -110,7 +110,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let image = if args.image.ends_with(".safetensors") {
|
let image = if args.image.ends_with(".safetensors") {
|
||||||
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
||||||
match tensors.remove("image") {
|
let image = match tensors.remove("image") {
|
||||||
Some(image) => image,
|
Some(image) => image,
|
||||||
None => {
|
None => {
|
||||||
if tensors.len() != 1 {
|
if tensors.len() != 1 {
|
||||||
@ -118,6 +118,11 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
tensors.into_values().next().unwrap()
|
tensors.into_values().next().unwrap()
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
if image.rank() == 4 {
|
||||||
|
image.get(0)?
|
||||||
|
} else {
|
||||||
|
image
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
|
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
|
||||||
|
@ -157,7 +157,28 @@ impl PromptEncoder {
|
|||||||
let point_embedding = self
|
let point_embedding = self
|
||||||
.pe_layer
|
.pe_layer
|
||||||
.forward_with_coords(&points, self.input_image_size)?;
|
.forward_with_coords(&points, self.input_image_size)?;
|
||||||
// TODO: tweak based on labels.
|
let zeros = point_embedding.zeros_like()?;
|
||||||
|
let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||||
|
&self
|
||||||
|
.not_a_point_embed
|
||||||
|
.embeddings()
|
||||||
|
.broadcast_as(zeros.shape())?,
|
||||||
|
&point_embedding,
|
||||||
|
)?;
|
||||||
|
let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
|
||||||
|
&self.point_embeddings[0]
|
||||||
|
.embeddings()
|
||||||
|
.broadcast_as(zeros.shape())?,
|
||||||
|
&zeros,
|
||||||
|
)?;
|
||||||
|
let point_embedding = (point_embedding + labels0)?;
|
||||||
|
let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
|
||||||
|
&self.point_embeddings[1]
|
||||||
|
.embeddings()
|
||||||
|
.broadcast_as(zeros.shape())?,
|
||||||
|
&zeros,
|
||||||
|
)?;
|
||||||
|
let point_embedding = (point_embedding + labels1)?;
|
||||||
Ok(point_embedding)
|
Ok(point_embedding)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user