mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Improve the safetensor loading in the segment-anything example. (#772)
* Improve the safetensor loading in the segment-anything example. * Properly handle the labels when embedding the point prompts.
This commit is contained in:
@ -110,7 +110,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let image = if args.image.ends_with(".safetensors") {
|
||||
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
||||
match tensors.remove("image") {
|
||||
let image = match tensors.remove("image") {
|
||||
Some(image) => image,
|
||||
None => {
|
||||
if tensors.len() != 1 {
|
||||
@ -118,6 +118,11 @@ pub fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
tensors.into_values().next().unwrap()
|
||||
}
|
||||
};
|
||||
if image.rank() == 4 {
|
||||
image.get(0)?
|
||||
} else {
|
||||
image
|
||||
}
|
||||
} else {
|
||||
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
|
||||
|
@ -157,7 +157,28 @@ impl PromptEncoder {
|
||||
let point_embedding = self
|
||||
.pe_layer
|
||||
.forward_with_coords(&points, self.input_image_size)?;
|
||||
// TODO: tweak based on labels.
|
||||
let zeros = point_embedding.zeros_like()?;
|
||||
let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||
&self
|
||||
.not_a_point_embed
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&point_embedding,
|
||||
)?;
|
||||
let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
|
||||
&self.point_embeddings[0]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels0)?;
|
||||
let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
|
||||
&self.point_embeddings[1]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels1)?;
|
||||
Ok(point_embedding)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user