mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Segment Anything - process images (#766)
* Start processing images.
* Add LayerNorm2d.
* Properly use LayerNorm2d.
* Tweak eps.
* Use LayerNorm on inputs with a rank different from 3.
* Window partitioning.
* Fix a couple todos.
* More todos.
* Hard-code the einsums.
* More padding support.
* Some sizes tweaks.
* Use the hub to get the weights.
* Use a batch matmul.
* Tweaks.
* More fixes.
* Get some predictions to be generated.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{Linear, Module, VarBuilder};
|
||||
|
||||
use crate::model_transformer::TwoWayTransformer;
|
||||
|
||||
@ -60,7 +60,7 @@ pub struct MaskDecoder {
|
||||
mask_tokens: candle_nn::Embedding,
|
||||
iou_prediction_head: MlpMaskDecoder,
|
||||
output_upscaling_conv1: candle_nn::ConvTranspose2d,
|
||||
output_upscaling_ln: LayerNorm,
|
||||
output_upscaling_ln: crate::LayerNorm2d,
|
||||
output_upscaling_conv2: candle_nn::ConvTranspose2d,
|
||||
num_mask_tokens: usize,
|
||||
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
|
||||
@ -99,7 +99,7 @@ impl MaskDecoder {
|
||||
vb.pp("output_upscaling.0"),
|
||||
)?;
|
||||
let output_upscaling_ln =
|
||||
layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
|
||||
crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
|
||||
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
|
||||
transformer_dim / 4,
|
||||
transformer_dim / 8,
|
||||
@ -140,7 +140,7 @@ impl MaskDecoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
pub fn forward(
|
||||
&self,
|
||||
image_embeddings: &Tensor,
|
||||
image_pe: &Tensor,
|
||||
@ -195,7 +195,7 @@ impl MaskDecoder {
|
||||
// Run the transformer
|
||||
let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
|
||||
let iou_token_out = hs.i((.., 0))?;
|
||||
let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
|
||||
let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
|
||||
|
||||
// Upscale mask embeddings and predict masks using the masks tokens.
|
||||
let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
|
||||
@ -213,9 +213,8 @@ impl MaskDecoder {
|
||||
}
|
||||
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
|
||||
let (b, c, h, w) = upscaled_embedding.dims4()?;
|
||||
let masks = hyper_in
|
||||
.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
|
||||
.reshape((b, 0, h, w))?;
|
||||
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
|
||||
let masks = masks.reshape((b, masks.elem_count() / b / h / w, h, w))?;
|
||||
|
||||
// Generate mask quality predictions.
|
||||
let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
|
||||
@ -224,6 +223,9 @@ impl MaskDecoder {
|
||||
}
|
||||
|
||||
// Equivalent to torch.repeat_interleave
|
||||
fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
|
||||
todo!()
|
||||
/// Repeats the entries of `img` `repeats` times along dimension `dim`, with the
/// copies of each entry placed next to each other (intended to match
/// `torch.repeat_interleave` called with an integer repeat count).
///
/// The result has the same rank as `img`; the size of dimension `dim` is
/// multiplied by `repeats`.
///
/// # Errors
///
/// Propagates any error returned by the underlying tensor operations
/// (`unsqueeze`, `broadcast_as`, `flatten`).
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
|
||||
// Insert a new axis immediately after `dim`.
let img = img.unsqueeze(dim + 1)?;
|
||||
// Start from the shape that includes the freshly inserted axis.
let mut dims = img.dims().to_vec();
|
||||
// Target shape: the inserted axis is expanded to `repeats`.
dims[dim + 1] = repeats;
|
||||
// Broadcast to the target shape, then merge `dim` and the inserted axis into one dimension.
img.broadcast_as(dims)?.flatten(dim, dim + 1)
|
||||
}
|
||||
|
Reference in New Issue
Block a user