More segment-anything. (#763)

* More segment-anything.

* Split the model in multiple files.

* Start adding the transformer.

* Add the attention block.

* Move the MLP Block.
This commit is contained in:
Laurent Mazare
2023-09-07 08:28:30 +02:00
committed by GitHub
parent 000fa00e31
commit 8c991df394
5 changed files with 574 additions and 349 deletions

View File

@ -8,16 +8,15 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
mod model_image_encoder;
mod model_mask_decoder;
mod model_transformer;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use clap::Parser;
const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias {
candle_nn::linear(in_dim, out_dim, vb)
} else {
@ -26,13 +25,13 @@ fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<L
}
#[derive(Debug)]
struct MlpBlock {
pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
}
impl MlpBlock {
fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
pub fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
Ok(Self { lin1, lin2 })
@ -45,347 +44,6 @@ impl Module for MlpBlock {
}
}
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
}
impl PatchEmbed {
fn new(
in_chans: usize,
embed_dim: usize,
k_size: usize,
stride: usize,
padding: usize,
vb: VarBuilder,
) -> Result<Self> {
let cfg = candle_nn::Conv2dConfig {
stride,
padding,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
Ok(Self { proj })
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.proj)?.permute((0, 2, 3, 1))
}
}
#[derive(Debug)]
struct Attention {
qkv: Linear,
proj: Linear,
num_heads: usize,
scale: f64,
use_rel_pos: bool,
rel_pos_hw: Option<(Tensor, Tensor)>,
}
impl Attention {
fn new(
dim: usize,
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
vb: VarBuilder,
) -> Result<Self> {
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = linear(vb.pp("proj"), dim, dim, true)?;
let head_dim = dim / num_heads;
let scale = 1. / (head_dim as f64).sqrt();
let rel_pos_hw = if use_rel_pos {
let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
Some((h, w))
} else {
None
};
Ok(Self {
qkv,
proj,
num_heads,
scale,
use_rel_pos,
rel_pos_hw,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, h, w, c) = xs.dims4()?;
let qkv = self
.qkv
.forward(xs)?
.reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
.permute((2, 0, 3, 1, 4))?
.reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
let q = qkv.i(0)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = (q * self.scale)?.matmul(&k.t()?)?;
if self.use_rel_pos {
todo!()
}
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
let attn = attn
.matmul(&v)?
.reshape((b, self.num_heads, h, w, c / self.num_heads))?
.permute((0, 2, 3, 1, 4))?
.reshape((b, h, w, c / self.num_heads))?;
self.proj.forward(&attn)
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
norm2: LayerNorm,
mlp: MlpBlock,
window_size: usize,
}
impl Block {
fn new(
dim: usize,
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
vb: VarBuilder,
) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let attn = Attention::new(
dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
vb.pp("attn"),
)?;
let mlp = MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
Ok(Self {
norm1,
attn,
norm2,
mlp,
window_size,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let shortcut = xs;
let xs = self.norm1.forward(xs)?;
if self.window_size > 0 {
todo!()
}
let xs = self.attn.forward(&xs)?;
if self.window_size > 0 {
todo!()
}
let xs = (xs + shortcut)?;
&xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
}
}
#[derive(Debug)]
struct ImageEncoderViT {
img_size: usize,
patch_embed: PatchEmbed,
blocks: Vec<Block>,
neck_conv1: candle_nn::Conv2d,
neck_ln1: LayerNorm,
neck_conv2: candle_nn::Conv2d,
neck_ln2: LayerNorm,
pos_embed: Option<Tensor>,
}
impl ImageEncoderViT {
#[allow(clippy::too_many_arguments)]
fn new(
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
depth: usize,
num_heads: usize,
out_chans: usize,
qkv_bias: bool,
use_rel_pos: bool,
use_abs_pos: bool,
window_size: usize,
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(
in_chans,
embed_dim,
patch_size,
patch_size,
0,
vb.pp("patch_embed"),
)?;
let mut blocks = Vec::with_capacity(depth);
let vb_b = vb.pp("blocks");
for i in 0..depth {
let block = Block::new(
embed_dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
vb_b.pp(i),
)?;
blocks.push(block)
}
let neck_conv1 = candle_nn::conv2d_no_bias(
embed_dim,
out_chans,
1,
Default::default(),
vb.pp("neck.0"),
)?;
let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?;
let pos_embed = if use_abs_pos {
let p = vb.get(
(1, img_size / patch_size, img_size / patch_size, embed_dim),
"pos_embed",
)?;
Some(p)
} else {
None
};
Ok(Self {
img_size,
patch_embed,
blocks,
neck_conv1,
neck_ln1,
neck_conv2,
neck_ln2,
pos_embed,
})
}
}
impl Module for ImageEncoderViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.patch_embed.forward(xs)?;
let mut xs = match &self.pos_embed {
Some(pos_embed) => (xs + pos_embed)?,
None => xs,
};
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
xs.permute((0, 3, 1, 2))?
.apply(&self.neck_conv1)?
.apply(&self.neck_ln1)?
.apply(&self.neck_conv2)?
.apply(&self.neck_ln2)
}
}
#[derive(Debug)]
struct MlpMaskDecoder {
layers: Vec<Linear>,
sigmoid_output: bool,
}
impl MlpMaskDecoder {
fn new(
input_dim: usize,
hidden_dim: usize,
output_dim: usize,
num_layers: usize,
sigmoid_output: bool,
vb: VarBuilder,
) -> Result<Self> {
let mut layers = Vec::with_capacity(num_layers);
let vb = vb.pp("layers");
for i in 0..num_layers {
let in_dim = if i == 0 { input_dim } else { hidden_dim };
let out_dim = if i + 1 == num_layers {
output_dim
} else {
hidden_dim
};
let layer = linear(vb.pp(i), in_dim, out_dim, true)?;
layers.push(layer)
}
Ok(Self {
layers,
sigmoid_output,
})
}
}
impl Module for MlpMaskDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for (i, layer) in self.layers.iter().enumerate() {
xs = layer.forward(&xs)?;
if i + 1 < self.layers.len() {
xs = xs.relu()?
}
}
if self.sigmoid_output {
candle_nn::ops::sigmoid(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
struct MaskDecoder {
iou_tokens: candle_nn::Embedding,
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
}
impl MaskDecoder {
fn new(
transformer_dim: usize,
num_multimask_outputs: usize,
iou_head_depth: usize,
iou_head_hidden_dim: usize,
vb: VarBuilder,
) -> Result<Self> {
let num_mask_tokens = num_multimask_outputs - 1;
let iou_prediction_head = MlpMaskDecoder::new(
transformer_dim,
iou_head_hidden_dim,
num_mask_tokens,
iou_head_depth,
false,
vb.pp("iou_prediction_head"),
)?;
let iou_tokens = candle_nn::embedding(1, transformer_dim, vb.pp("iou_tokens"))?;
let mask_tokens =
candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
Ok(Self {
iou_tokens,
mask_tokens,
iou_prediction_head,
})
}
}
/*
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
let npatch = xs.dim(1)? - 1;

View File

@ -0,0 +1,257 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
}
impl PatchEmbed {
fn new(
in_chans: usize,
embed_dim: usize,
k_size: usize,
stride: usize,
padding: usize,
vb: VarBuilder,
) -> Result<Self> {
let cfg = candle_nn::Conv2dConfig {
stride,
padding,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
Ok(Self { proj })
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.proj)?.permute((0, 2, 3, 1))
}
}
#[derive(Debug)]
struct Attention {
qkv: Linear,
proj: Linear,
num_heads: usize,
scale: f64,
use_rel_pos: bool,
rel_pos_hw: Option<(Tensor, Tensor)>,
}
impl Attention {
fn new(
dim: usize,
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
vb: VarBuilder,
) -> Result<Self> {
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
let head_dim = dim / num_heads;
let scale = 1. / (head_dim as f64).sqrt();
let rel_pos_hw = if use_rel_pos {
let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
Some((h, w))
} else {
None
};
Ok(Self {
qkv,
proj,
num_heads,
scale,
use_rel_pos,
rel_pos_hw,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, h, w, c) = xs.dims4()?;
let qkv = self
.qkv
.forward(xs)?
.reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
.permute((2, 0, 3, 1, 4))?
.reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
let q = qkv.i(0)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = (q * self.scale)?.matmul(&k.t()?)?;
if self.use_rel_pos {
todo!()
}
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
let attn = attn
.matmul(&v)?
.reshape((b, self.num_heads, h, w, c / self.num_heads))?
.permute((0, 2, 3, 1, 4))?
.reshape((b, h, w, c / self.num_heads))?;
self.proj.forward(&attn)
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
norm2: LayerNorm,
mlp: crate::MlpBlock,
window_size: usize,
}
impl Block {
fn new(
dim: usize,
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
vb: VarBuilder,
) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let attn = Attention::new(
dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
vb.pp("attn"),
)?;
let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
Ok(Self {
norm1,
attn,
norm2,
mlp,
window_size,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let shortcut = xs;
let xs = self.norm1.forward(xs)?;
if self.window_size > 0 {
todo!()
}
let xs = self.attn.forward(&xs)?;
if self.window_size > 0 {
todo!()
}
let xs = (xs + shortcut)?;
&xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
}
}
#[derive(Debug)]
struct ImageEncoderViT {
img_size: usize,
patch_embed: PatchEmbed,
blocks: Vec<Block>,
neck_conv1: candle_nn::Conv2d,
neck_ln1: LayerNorm,
neck_conv2: candle_nn::Conv2d,
neck_ln2: LayerNorm,
pos_embed: Option<Tensor>,
}
impl ImageEncoderViT {
#[allow(clippy::too_many_arguments)]
fn new(
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
depth: usize,
num_heads: usize,
out_chans: usize,
qkv_bias: bool,
use_rel_pos: bool,
use_abs_pos: bool,
window_size: usize,
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(
in_chans,
embed_dim,
patch_size,
patch_size,
0,
vb.pp("patch_embed"),
)?;
let mut blocks = Vec::with_capacity(depth);
let vb_b = vb.pp("blocks");
for i in 0..depth {
let block = Block::new(
embed_dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
vb_b.pp(i),
)?;
blocks.push(block)
}
let neck_conv1 = candle_nn::conv2d_no_bias(
embed_dim,
out_chans,
1,
Default::default(),
vb.pp("neck.0"),
)?;
let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?;
let pos_embed = if use_abs_pos {
let p = vb.get(
(1, img_size / patch_size, img_size / patch_size, embed_dim),
"pos_embed",
)?;
Some(p)
} else {
None
};
Ok(Self {
img_size,
patch_embed,
blocks,
neck_conv1,
neck_ln1,
neck_conv2,
neck_ln2,
pos_embed,
})
}
}
impl Module for ImageEncoderViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.patch_embed.forward(xs)?;
let mut xs = match &self.pos_embed {
Some(pos_embed) => (xs + pos_embed)?,
None => xs,
};
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
xs.permute((0, 3, 1, 2))?
.apply(&self.neck_conv1)?
.apply(&self.neck_ln1)?
.apply(&self.neck_conv2)?
.apply(&self.neck_ln2)
}
}

View File

@ -0,0 +1,222 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
struct MlpMaskDecoder {
layers: Vec<Linear>,
sigmoid_output: bool,
}
impl MlpMaskDecoder {
fn new(
input_dim: usize,
hidden_dim: usize,
output_dim: usize,
num_layers: usize,
sigmoid_output: bool,
vb: VarBuilder,
) -> Result<Self> {
let mut layers = Vec::with_capacity(num_layers);
let vb = vb.pp("layers");
for i in 0..num_layers {
let in_dim = if i == 0 { input_dim } else { hidden_dim };
let out_dim = if i + 1 == num_layers {
output_dim
} else {
hidden_dim
};
let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
layers.push(layer)
}
Ok(Self {
layers,
sigmoid_output,
})
}
}
impl Module for MlpMaskDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for (i, layer) in self.layers.iter().enumerate() {
xs = layer.forward(&xs)?;
if i + 1 < self.layers.len() {
xs = xs.relu()?
}
}
if self.sigmoid_output {
candle_nn::ops::sigmoid(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
struct MaskDecoder {
iou_token: candle_nn::Embedding,
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
output_upscaling_conv1: candle_nn::ConvTranspose2d,
output_upscaling_ln: LayerNorm,
output_upscaling_conv2: candle_nn::ConvTranspose2d,
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
}
impl MaskDecoder {
fn new(
transformer_dim: usize,
num_multimask_outputs: usize,
iou_head_depth: usize,
iou_head_hidden_dim: usize,
vb: VarBuilder,
) -> Result<Self> {
let num_mask_tokens = num_multimask_outputs - 1;
let iou_prediction_head = MlpMaskDecoder::new(
transformer_dim,
iou_head_hidden_dim,
num_mask_tokens,
iou_head_depth,
false,
vb.pp("iou_prediction_head"),
)?;
let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
let mask_tokens =
candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
let cfg = candle_nn::ConvTranspose2dConfig {
stride: 2,
..Default::default()
};
let output_upscaling_conv1 = candle_nn::conv_transpose2d(
transformer_dim,
transformer_dim / 4,
2,
cfg,
vb.pp("output_upscaling.0"),
)?;
let output_upscaling_ln =
layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
transformer_dim / 4,
transformer_dim / 8,
2,
cfg,
vb.pp("output_upscaling.3"),
)?;
let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
let vb_o = vb.pp("output_hypernetworks_mlps");
for i in 0..num_mask_tokens {
let mlp = MlpMaskDecoder::new(
transformer_dim,
transformer_dim,
transformer_dim / 8,
3,
false,
vb_o.pp(i),
)?;
output_hypernetworks_mlps.push(mlp)
}
Ok(Self {
iou_token,
mask_tokens,
iou_prediction_head,
output_upscaling_conv1,
output_upscaling_ln,
output_upscaling_conv2,
num_mask_tokens,
output_hypernetworks_mlps,
})
}
fn forward(
&self,
image_embeddings: &Tensor,
image_pe: &Tensor,
sparse_prompt_embeddings: &Tensor,
dense_prompt_embeddings: &Tensor,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let (masks, iou_pred) = self.predict_masks(
image_embeddings,
image_pe,
sparse_prompt_embeddings,
dense_prompt_embeddings,
)?;
let masks = if multimask_output {
masks.i((.., 1..))?
} else {
masks.i((.., 0..1))?
};
let iou_pred = if multimask_output {
iou_pred.i((.., 1..))?
} else {
iou_pred.i((.., 0..1))?
};
Ok((masks, iou_pred))
}
fn predict_masks(
&self,
image_embeddings: &Tensor,
image_pe: &Tensor,
sparse_prompt_embeddings: &Tensor,
dense_prompt_embeddings: &Tensor,
) -> Result<(Tensor, Tensor)> {
// Concatenate ouput tokens.
let output_tokens = Tensor::cat(
&[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
0,
)?;
let (d1, d2) = output_tokens.dims2()?;
let output_tokens =
output_tokens
.unsqueeze(0)?
.expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
// Expand per-image data in batch direction to be per mask
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
let src = (src + dense_prompt_embeddings)?;
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
let (b, c, h, w) = src.dims4()?;
// Run the transformer
let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
let iou_token_out = hs.i((.., 0))?;
let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
// Upscale mask embeddings and predict masks using the masks tokens.
let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
let upscaled_embedding = self
.output_upscaling_conv1
.forward(&src)?
.apply(&self.output_upscaling_ln)?
.gelu()?
.apply(&self.output_upscaling_conv2)?
.gelu()?;
let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
hyper_in_list.push(h)
}
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
let (b, c, h, w) = upscaled_embedding.dims4()?;
let masks = hyper_in
.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
.reshape((b, 0, h, w))?;
// Generate mask quality predictions.
let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
Ok((masks, iou_pred))
}
}
// Equivalent to torch.repeat_interleave
fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
todo!()
}
fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
todo!()
}

View File

@ -0,0 +1,77 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
internal_dim: usize,
num_heads: usize,
}
impl Attention {
fn new(
embedding_dim: usize,
num_heads: usize,
downsample_rate: usize,
vb: VarBuilder,
) -> Result<Self> {
let internal_dim = embedding_dim / downsample_rate;
let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
internal_dim,
num_heads,
})
}
fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n, c) = x.dims3()?;
x.reshape((b, n, self.num_heads, c / self.num_heads))?
.transpose(1, 2)
}
fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
x.transpose(1, 2)?
.reshape((b, n_tokens, n_heads * c_per_head))
}
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let q = self.q_proj.forward(q)?;
let k = self.k_proj.forward(k)?;
let v = self.v_proj.forward(v)?;
let q = self.separate_heads(&q)?;
let k = self.separate_heads(&k)?;
let v = self.separate_heads(&v)?;
let (_, _, _, c_per_head) = q.dims4()?;
let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
let out = attn.matmul(&v)?;
self.recombine_heads(&out)?.apply(&self.out_proj)
}
}
#[derive(Debug)]
struct TwoWayAttentionBlock {
self_attn: Attention,
norm1: LayerNorm,
cross_attn_token_to_image: Attention,
norm2: LayerNorm,
mlp: crate::MlpBlock,
norm3: LayerNorm,
norm4: LayerNorm,
cross_attn_image_to_token: Attention,
skip_first_layer_pe: bool,
}

View File

@ -130,6 +130,17 @@ pub struct ConvTranspose2dConfig {
// TODO: support groups.
}
impl Default for ConvTranspose2dConfig {
fn default() -> Self {
Self {
padding: 0,
output_padding: 0,
stride: 1,
dilation: 1,
}
}
}
#[derive(Debug)]
pub struct ConvTranspose2d {
weight: Tensor,