mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values. * Add some time measurement.
This commit is contained in:
@ -209,12 +209,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
} else {
|
||||
let point = Some((args.point_x, args.point_y));
|
||||
let start_time = std::time::Instant::now();
|
||||
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
||||
println!(
|
||||
"mask generated in {:.2}s",
|
||||
start_time.elapsed().as_secs_f32()
|
||||
);
|
||||
println!("mask:\n{mask}");
|
||||
println!("iou_predictions: {iou_predictions:?}");
|
||||
|
||||
// Save the mask as an image.
|
||||
let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
|
||||
let mask = (mask.ge(0f32)? * 255.)?;
|
||||
let (_one, h, w) = mask.dims3()?;
|
||||
let mask = mask.expand((3, h, w))?;
|
||||
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
|
||||
|
@ -161,21 +161,21 @@ impl PromptEncoder {
|
||||
.forward_with_coords(&points, self.input_image_size)?;
|
||||
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
|
||||
let zeros = point_embedding.zeros_like()?;
|
||||
let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||
let point_embedding = labels.lt(0f32)?.where_cond(
|
||||
&self
|
||||
.not_a_point_embed
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&point_embedding,
|
||||
)?;
|
||||
let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
|
||||
let labels0 = labels.eq(0f32)?.where_cond(
|
||||
&self.point_embeddings[0]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels0)?;
|
||||
let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
|
||||
let labels1 = labels.eq(1f32)?.where_cond(
|
||||
&self.point_embeddings[1]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
|
Reference in New Issue
Block a user