mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
More segment-anything. (#763)
* More segment-anything. * Split the model in multiple files. * Start adding the transformer. * Add the attention block. * Move the MLP Block.
This commit is contained in:
222
candle-examples/examples/segment-anything/model_mask_decoder.rs
Normal file
222
candle-examples/examples/segment-anything/model_mask_decoder.rs
Normal file
@ -0,0 +1,222 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
/// Multi-layer perceptron used inside the mask decoder: a stack of linear
/// layers with ReLU between them and an optional sigmoid on the output
/// (see the `Module` impl for the forward pass).
#[derive(Debug)]
struct MlpMaskDecoder {
    // Linear layers applied in sequence; ReLU is inserted between
    // consecutive layers but not after the last one.
    layers: Vec<Linear>,
    // When true, a sigmoid is applied to the final output.
    sigmoid_output: bool,
}
|
||||
|
||||
impl MlpMaskDecoder {
|
||||
fn new(
|
||||
input_dim: usize,
|
||||
hidden_dim: usize,
|
||||
output_dim: usize,
|
||||
num_layers: usize,
|
||||
sigmoid_output: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
let vb = vb.pp("layers");
|
||||
for i in 0..num_layers {
|
||||
let in_dim = if i == 0 { input_dim } else { hidden_dim };
|
||||
let out_dim = if i + 1 == num_layers {
|
||||
output_dim
|
||||
} else {
|
||||
hidden_dim
|
||||
};
|
||||
let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
|
||||
layers.push(layer)
|
||||
}
|
||||
Ok(Self {
|
||||
layers,
|
||||
sigmoid_output,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MlpMaskDecoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
xs = layer.forward(&xs)?;
|
||||
if i + 1 < self.layers.len() {
|
||||
xs = xs.relu()?
|
||||
}
|
||||
}
|
||||
if self.sigmoid_output {
|
||||
candle_nn::ops::sigmoid(&xs)
|
||||
} else {
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SAM mask decoder: turns image embeddings plus prompt embeddings into
/// predicted masks and mask-quality (IoU) scores.
#[derive(Debug)]
struct MaskDecoder {
    // Learned token prepended to the prompt tokens; its transformer output
    // feeds `iou_prediction_head`.
    iou_token: candle_nn::Embedding,
    // Learned tokens whose transformer outputs drive the per-mask
    // hypernetwork MLPs.
    mask_tokens: candle_nn::Embedding,
    // MLP predicting an IoU score per mask from the iou token output.
    iou_prediction_head: MlpMaskDecoder,
    // Upscaling stack: conv-transpose -> layer norm -> conv-transpose
    // (with GELUs applied in `predict_masks`).
    output_upscaling_conv1: candle_nn::ConvTranspose2d,
    output_upscaling_ln: LayerNorm,
    output_upscaling_conv2: candle_nn::ConvTranspose2d,
    // Number of mask tokens / predicted masks.
    num_mask_tokens: usize,
    // One hypernetwork MLP per mask token.
    output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
}
|
||||
|
||||
impl MaskDecoder {
|
||||
fn new(
|
||||
transformer_dim: usize,
|
||||
num_multimask_outputs: usize,
|
||||
iou_head_depth: usize,
|
||||
iou_head_hidden_dim: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let num_mask_tokens = num_multimask_outputs - 1;
|
||||
let iou_prediction_head = MlpMaskDecoder::new(
|
||||
transformer_dim,
|
||||
iou_head_hidden_dim,
|
||||
num_mask_tokens,
|
||||
iou_head_depth,
|
||||
false,
|
||||
vb.pp("iou_prediction_head"),
|
||||
)?;
|
||||
let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
|
||||
let mask_tokens =
|
||||
candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
|
||||
let cfg = candle_nn::ConvTranspose2dConfig {
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let output_upscaling_conv1 = candle_nn::conv_transpose2d(
|
||||
transformer_dim,
|
||||
transformer_dim / 4,
|
||||
2,
|
||||
cfg,
|
||||
vb.pp("output_upscaling.0"),
|
||||
)?;
|
||||
let output_upscaling_ln =
|
||||
layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
|
||||
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
|
||||
transformer_dim / 4,
|
||||
transformer_dim / 8,
|
||||
2,
|
||||
cfg,
|
||||
vb.pp("output_upscaling.3"),
|
||||
)?;
|
||||
let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
|
||||
let vb_o = vb.pp("output_hypernetworks_mlps");
|
||||
for i in 0..num_mask_tokens {
|
||||
let mlp = MlpMaskDecoder::new(
|
||||
transformer_dim,
|
||||
transformer_dim,
|
||||
transformer_dim / 8,
|
||||
3,
|
||||
false,
|
||||
vb_o.pp(i),
|
||||
)?;
|
||||
output_hypernetworks_mlps.push(mlp)
|
||||
}
|
||||
Ok(Self {
|
||||
iou_token,
|
||||
mask_tokens,
|
||||
iou_prediction_head,
|
||||
output_upscaling_conv1,
|
||||
output_upscaling_ln,
|
||||
output_upscaling_conv2,
|
||||
num_mask_tokens,
|
||||
output_hypernetworks_mlps,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
image_embeddings: &Tensor,
|
||||
image_pe: &Tensor,
|
||||
sparse_prompt_embeddings: &Tensor,
|
||||
dense_prompt_embeddings: &Tensor,
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (masks, iou_pred) = self.predict_masks(
|
||||
image_embeddings,
|
||||
image_pe,
|
||||
sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings,
|
||||
)?;
|
||||
let masks = if multimask_output {
|
||||
masks.i((.., 1..))?
|
||||
} else {
|
||||
masks.i((.., 0..1))?
|
||||
};
|
||||
let iou_pred = if multimask_output {
|
||||
iou_pred.i((.., 1..))?
|
||||
} else {
|
||||
iou_pred.i((.., 0..1))?
|
||||
};
|
||||
Ok((masks, iou_pred))
|
||||
}
|
||||
|
||||
fn predict_masks(
|
||||
&self,
|
||||
image_embeddings: &Tensor,
|
||||
image_pe: &Tensor,
|
||||
sparse_prompt_embeddings: &Tensor,
|
||||
dense_prompt_embeddings: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
// Concatenate ouput tokens.
|
||||
let output_tokens = Tensor::cat(
|
||||
&[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
|
||||
0,
|
||||
)?;
|
||||
let (d1, d2) = output_tokens.dims2()?;
|
||||
let output_tokens =
|
||||
output_tokens
|
||||
.unsqueeze(0)?
|
||||
.expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
|
||||
let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
|
||||
|
||||
// Expand per-image data in batch direction to be per mask
|
||||
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
|
||||
let src = (src + dense_prompt_embeddings)?;
|
||||
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
|
||||
let (b, c, h, w) = src.dims4()?;
|
||||
|
||||
// Run the transformer
|
||||
let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
|
||||
let iou_token_out = hs.i((.., 0))?;
|
||||
let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
|
||||
|
||||
// Upscale mask embeddings and predict masks using the masks tokens.
|
||||
let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
|
||||
let upscaled_embedding = self
|
||||
.output_upscaling_conv1
|
||||
.forward(&src)?
|
||||
.apply(&self.output_upscaling_ln)?
|
||||
.gelu()?
|
||||
.apply(&self.output_upscaling_conv2)?
|
||||
.gelu()?;
|
||||
let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
|
||||
for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
|
||||
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
|
||||
hyper_in_list.push(h)
|
||||
}
|
||||
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
|
||||
let (b, c, h, w) = upscaled_embedding.dims4()?;
|
||||
let masks = hyper_in
|
||||
.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
|
||||
.reshape((b, 0, h, w))?;
|
||||
|
||||
// Generate mask quality predictions.
|
||||
let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
|
||||
Ok((masks, iou_pred))
|
||||
}
|
||||
}
|
||||
|
||||
// Equivalent to torch.repeat_interleave
|
||||
fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// Placeholder for the mask decoder's two-way transformer — presumably the
// transformer/attention blocks mentioned in the commit message are still
// being ported, so this panics if reached. Once implemented it should return
// the (token embeddings, image embeddings) pair consumed by `predict_masks`.
fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
    todo!()
}
|
Reference in New Issue
Block a user