diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs index f5db2830..79e52d47 100644 --- a/candle-examples/examples/segment-anything/model_image_encoder.rs +++ b/candle-examples/examples/segment-anything/model_image_encoder.rs @@ -123,7 +123,7 @@ fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result let relative_coords = relative_coords.to_dtype(DType::U32)?; rel_pos_resized .index_select(&relative_coords.reshape(d1 * d2)?, 0)? - .reshape((d1, d2, rel_pos_resized.dim(1)?)) + .reshape((d1, d2, ())) } impl Module for Attention { @@ -243,7 +243,7 @@ fn window_unpartition( ))? .transpose(2, 3)? .contiguous()? - .reshape((b, h_p, w_p, windows.elem_count() / b / h_p / w_p))?; + .reshape((b, h_p, w_p, ()))?; let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs }; let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs }; Ok(xs) diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs index 1ef46eeb..acbfeeea 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs @@ -214,7 +214,7 @@ impl MaskDecoder { let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?; let (b, c, h, w) = upscaled_embedding.dims4()?; let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?; - let masks = masks.reshape((b, masks.elem_count() / b / h / w, h, w))?; + let masks = masks.reshape((b, (), h, w))?; // Generate mask quality predictions. let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?; diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs index c6ffffd2..aab0c4fd 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs @@ -28,10 +28,10 @@ impl PostionEmbeddingRandom { let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?; let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?; let x_embed = (x_embed / w as f64)? - .reshape((1, w))? + .reshape((1, ()))? .broadcast_as((h, w))?; let y_embed = (y_embed / h as f64)? - .reshape((h, 1))? + .reshape(((), 1))? .broadcast_as((h, w))?; let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?; self.pe_encoding(&coords)?.permute((2, 0, 1)) @@ -163,7 +163,7 @@ impl PromptEncoder { fn embed_boxes(&self, boxes: &Tensor) -> Result { let boxes = (boxes + 0.5)?; - let coords = boxes.reshape((boxes.elem_count() / 4, 2, 2))?; + let coords = boxes.reshape(((), 2, 2))?; let corner_embedding = self .pe_layer .forward_with_coords(&coords, self.input_image_size)?; @@ -200,7 +200,7 @@ impl PromptEncoder { let dense_embeddings = match masks { None => { let emb = self.no_mask_embed.embeddings(); - emb.reshape((1, emb.elem_count(), 1, 1))?.expand(( + emb.reshape((1, (), 1, 1))?.expand(( 1, emb.elem_count(), self.image_embedding_size.0,