Automatic mask generator + point base mask (#773)

* Add more to the automatic mask generator.

* Add the target point.

* Fix.

* Remove the allow-unused.

* Mask post-processing.
This commit is contained in:
Laurent Mazare
2023-09-08 12:26:56 +01:00
committed by GitHub
parent c1453f00b1
commit 28c87f6a34
7 changed files with 249 additions and 42 deletions

View File

@ -1,5 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{Linear, Module, VarBuilder};
use candle_nn::VarBuilder;
#[derive(Debug)]
struct PostionEmbeddingRandom {
@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
let device = self.positional_encoding_gaussian_matrix.device();
let grid = Tensor::ones((h, w), DType::F32, device)?;
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let x_embed = (x_embed / w as f64)?
@ -157,8 +156,9 @@ impl PromptEncoder {
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
&self
.not_a_point_embed
.embeddings()