Segment Anything - process images (#766)

* Start processing images.

* Add LayerNorm2d.

* Properly use LayerNorm2d.

* Tweak eps.

* Use LayerNorm on inputs with a rank different from 3.

* Window partitioning.

* Fix a couple todos.

* More todos.

* Hard-code the einsums.

* More padding support.

* Some sizes tweaks.

* Use the hub to get the weights.

* Use a batch matmul.

* Tweaks.

* More fixes.

* Get some predictions to be generated.
This commit is contained in:
Laurent Mazare
2023-09-07 19:22:45 +01:00
committed by GitHub
parent 7b50f3e106
commit 7396b8ed1a
10 changed files with 303 additions and 105 deletions

View File

@ -1,5 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{Linear, Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
@ -60,7 +60,7 @@ pub struct MaskDecoder {
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
output_upscaling_conv1: candle_nn::ConvTranspose2d,
output_upscaling_ln: LayerNorm,
output_upscaling_ln: crate::LayerNorm2d,
output_upscaling_conv2: candle_nn::ConvTranspose2d,
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
@ -99,7 +99,7 @@ impl MaskDecoder {
vb.pp("output_upscaling.0"),
)?;
let output_upscaling_ln =
layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
transformer_dim / 4,
transformer_dim / 8,
@ -140,7 +140,7 @@ impl MaskDecoder {
})
}
fn forward(
pub fn forward(
&self,
image_embeddings: &Tensor,
image_pe: &Tensor,
@ -195,7 +195,7 @@ impl MaskDecoder {
// Run the transformer
let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
let iou_token_out = hs.i((.., 0))?;
let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
// Upscale mask embeddings and predict masks using the masks tokens.
let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
@ -213,9 +213,8 @@ impl MaskDecoder {
}
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
let (b, c, h, w) = upscaled_embedding.dims4()?;
let masks = hyper_in
.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
.reshape((b, 0, h, w))?;
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
let masks = masks.reshape((b, masks.elem_count() / b / h / w, h, w))?;
// Generate mask quality predictions.
let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
@ -224,6 +223,9 @@ impl MaskDecoder {
}
// Equivalent to torch.repeat_interleave
fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
todo!()
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
let img = img.unsqueeze(dim + 1)?;
let mut dims = img.dims().to_vec();
dims[dim + 1] = repeats;
img.broadcast_as(dims)?.flatten(dim, dim + 1)
}