Automatic mask generator + point base mask (#773)

* Add more to the automatic mask generator.

* Add the target point.

* Fix.

* Remove the allow-unused.

* Mask post-processing.
This commit is contained in:
Laurent Mazare
2023-09-08 12:26:56 +01:00
committed by GitHub
parent c1453f00b1
commit 28c87f6a34
7 changed files with 249 additions and 42 deletions

View File

@ -1,4 +1,4 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
@ -188,7 +188,7 @@ impl MaskDecoder {
// Expand per-image data in batch direction to be per mask
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
let src = (src + dense_prompt_embeddings)?;
let src = src.broadcast_add(dense_prompt_embeddings)?;
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
let (b, c, h, w) = src.dims4()?;