Automatic mask generator + point base mask (#773)

* Add more to the automatic mask generator.

* Add the target point.

* Fix.

* Remove the allow-unused.

* Mask post-processing.
This commit is contained in:
Laurent Mazare
2023-09-08 12:26:56 +01:00
committed by GitHub
parent c1453f00b1
commit 28c87f6a34
7 changed files with 249 additions and 42 deletions

View File

@ -1,4 +1,4 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
@ -37,7 +37,6 @@ struct Attention {
proj: Linear,
num_heads: usize,
scale: f64,
use_rel_pos: bool,
rel_pos_hw: Option<(Tensor, Tensor)>,
}
@ -66,7 +65,6 @@ impl Attention {
proj,
num_heads,
scale,
use_rel_pos,
rel_pos_hw,
})
}
@ -272,7 +270,6 @@ impl Module for Block {
#[derive(Debug)]
pub struct ImageEncoderViT {
img_size: usize,
patch_embed: PatchEmbed,
blocks: Vec<Block>,
neck_conv1: candle_nn::Conv2d,
@ -350,7 +347,6 @@ impl ImageEncoderViT {
None
};
Ok(Self {
img_size,
patch_embed,
blocks,
neck_conv1,