mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Improve the safetensor loading in the segment-anything example. (#772)
* Improve the safetensor loading in the segment-anything example. * Properly handle the labels when embedding the point prompts.
This commit is contained in:
@ -157,7 +157,28 @@ impl PromptEncoder {
|
||||
let point_embedding = self
|
||||
.pe_layer
|
||||
.forward_with_coords(&points, self.input_image_size)?;
|
||||
// TODO: tweak based on labels.
|
||||
let zeros = point_embedding.zeros_like()?;
|
||||
let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||
&self
|
||||
.not_a_point_embed
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&point_embedding,
|
||||
)?;
|
||||
let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
|
||||
&self.point_embeddings[0]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels0)?;
|
||||
let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
|
||||
&self.point_embeddings[1]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels1)?;
|
||||
Ok(point_embedding)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user