mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator. * Add the target point. * Fix. * Remove the allow-unused. * Mask post-processing.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Linear, Module, VarBuilder};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PostionEmbeddingRandom {
|
||||
@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
|
||||
|
||||
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
|
||||
let device = self.positional_encoding_gaussian_matrix.device();
|
||||
let grid = Tensor::ones((h, w), DType::F32, device)?;
|
||||
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let x_embed = (x_embed / w as f64)?
|
||||
@ -157,8 +156,9 @@ impl PromptEncoder {
|
||||
let point_embedding = self
|
||||
.pe_layer
|
||||
.forward_with_coords(&points, self.input_image_size)?;
|
||||
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
|
||||
let zeros = point_embedding.zeros_like()?;
|
||||
let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||
let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||
&self
|
||||
.not_a_point_embed
|
||||
.embeddings()
|
||||
|
Reference in New Issue
Block a user