mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator. * Add the target point. * Fix. * Remove the allow-unused. * Mask post-processing.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle::{IndexOp, Result, Tensor};
|
||||
use candle_nn::{Linear, Module, VarBuilder};
|
||||
|
||||
use crate::model_transformer::TwoWayTransformer;
|
||||
@ -188,7 +188,7 @@ impl MaskDecoder {
|
||||
|
||||
// Expand per-image data in batch direction to be per mask
|
||||
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
|
||||
let src = (src + dense_prompt_embeddings)?;
|
||||
let src = src.broadcast_add(dense_prompt_embeddings)?;
|
||||
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
|
||||
let (b, c, h, w) = src.dims4()?;
|
||||
|
||||
|
Reference in New Issue
Block a user