mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values. * Add some time measurement.
This commit is contained in:
@ -161,21 +161,21 @@ impl PromptEncoder {
|
||||
.forward_with_coords(&points, self.input_image_size)?;
|
||||
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
|
||||
let zeros = point_embedding.zeros_like()?;
|
||||
let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||
let point_embedding = labels.lt(0f32)?.where_cond(
|
||||
&self
|
||||
.not_a_point_embed
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&point_embedding,
|
||||
)?;
|
||||
let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
|
||||
let labels0 = labels.eq(0f32)?.where_cond(
|
||||
&self.point_embeddings[0]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels0)?;
|
||||
let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
|
||||
let labels1 = labels.eq(1f32)?.where_cond(
|
||||
&self.point_embeddings[1]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
|
Reference in New Issue
Block a user