Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
More segment-anything. (#763)
* More segment-anything.
* Split the model in multiple files.
* Start adding the transformer.
* Add the attention block.
* Move the MLP Block.
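The commit also promotes the example's shared `linear` helper (and `MlpBlock`) to `pub` so the split-out modules can reach them as `crate::linear` and `crate::MlpBlock`. For reference, a minimal sketch of the helper; the else branch is an assumption, since the hunk below is truncated before it, with `candle_nn::linear_no_bias` as the obvious counterpart:

pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
    // Build a Linear layer with or without a bias term, reading weights from `vb`.
    if bias {
        candle_nn::linear(in_dim, out_dim, vb)
    } else {
        // Assumed: the truncated else branch most likely calls the no-bias constructor.
        candle_nn::linear_no_bias(in_dim, out_dim, vb)
    }
}

// The new modules then call it as, e.g.:
// let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;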
@@ -8,16 +8,15 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;
mod model_image_encoder;
mod model_mask_decoder;
mod model_transformer;

use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use clap::Parser;

const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;

fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
    if bias {
        candle_nn::linear(in_dim, out_dim, vb)
    } else {
@@ -26,13 +25,13 @@ fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
}

#[derive(Debug)]
struct MlpBlock {
pub struct MlpBlock {
    lin1: Linear,
    lin2: Linear,
}

impl MlpBlock {
    fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
    pub fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
        let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
        let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
        Ok(Self { lin1, lin2 })
@@ -45,347 +44,6 @@ impl Module for MlpBlock {
    }
}

#[derive(Debug)]
struct PatchEmbed {
    proj: candle_nn::Conv2d,
}

impl PatchEmbed {
    fn new(
        in_chans: usize,
        embed_dim: usize,
        k_size: usize,
        stride: usize,
        padding: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            stride,
            padding,
            ..Default::default()
        };
        let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
        Ok(Self { proj })
    }
}

impl Module for PatchEmbed {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.proj)?.permute((0, 2, 3, 1))
    }
}

#[derive(Debug)]
struct Attention {
    qkv: Linear,
    proj: Linear,
    num_heads: usize,
    scale: f64,
    use_rel_pos: bool,
    rel_pos_hw: Option<(Tensor, Tensor)>,
}

impl Attention {
    fn new(
        dim: usize,
        num_heads: usize,
        qkv_bias: bool,
        use_rel_pos: bool,
        window_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
        let proj = linear(vb.pp("proj"), dim, dim, true)?;
        let head_dim = dim / num_heads;
        let scale = 1. / (head_dim as f64).sqrt();
        let rel_pos_hw = if use_rel_pos {
            let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
            let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
            Some((h, w))
        } else {
            None
        };
        Ok(Self {
            qkv,
            proj,
            num_heads,
            scale,
            use_rel_pos,
            rel_pos_hw,
        })
    }
}

impl Module for Attention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b, h, w, c) = xs.dims4()?;
        let qkv = self
            .qkv
            .forward(xs)?
            .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
            .permute((2, 0, 3, 1, 4))?
            .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
        let q = qkv.i(0)?;
        let k = qkv.i(1)?;
        let v = qkv.i(2)?;
        let attn = (q * self.scale)?.matmul(&k.t()?)?;
        if self.use_rel_pos {
            todo!()
        }
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
        let attn = attn
            .matmul(&v)?
            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
            .permute((0, 2, 3, 1, 4))?
            .reshape((b, h, w, c / self.num_heads))?;
        self.proj.forward(&attn)
    }
}

#[derive(Debug)]
struct Block {
    norm1: LayerNorm,
    attn: Attention,
    norm2: LayerNorm,
    mlp: MlpBlock,
    window_size: usize,
}

impl Block {
    fn new(
        dim: usize,
        num_heads: usize,
        qkv_bias: bool,
        use_rel_pos: bool,
        window_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
        let attn = Attention::new(
            dim,
            num_heads,
            qkv_bias,
            use_rel_pos,
            window_size,
            vb.pp("attn"),
        )?;
        let mlp = MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
        Ok(Self {
            norm1,
            attn,
            norm2,
            mlp,
            window_size,
        })
    }
}

impl Module for Block {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let shortcut = xs;
        let xs = self.norm1.forward(xs)?;
        if self.window_size > 0 {
            todo!()
        }
        let xs = self.attn.forward(&xs)?;
        if self.window_size > 0 {
            todo!()
        }
        let xs = (xs + shortcut)?;
        &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
    }
}

#[derive(Debug)]
struct ImageEncoderViT {
    img_size: usize,
    patch_embed: PatchEmbed,
    blocks: Vec<Block>,
    neck_conv1: candle_nn::Conv2d,
    neck_ln1: LayerNorm,
    neck_conv2: candle_nn::Conv2d,
    neck_ln2: LayerNorm,
    pos_embed: Option<Tensor>,
}

impl ImageEncoderViT {
    #[allow(clippy::too_many_arguments)]
    fn new(
        img_size: usize,
        patch_size: usize,
        in_chans: usize,
        embed_dim: usize,
        depth: usize,
        num_heads: usize,
        out_chans: usize,
        qkv_bias: bool,
        use_rel_pos: bool,
        use_abs_pos: bool,
        window_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let patch_embed = PatchEmbed::new(
            in_chans,
            embed_dim,
            patch_size,
            patch_size,
            0,
            vb.pp("patch_embed"),
        )?;
        let mut blocks = Vec::with_capacity(depth);
        let vb_b = vb.pp("blocks");
        for i in 0..depth {
            let block = Block::new(
                embed_dim,
                num_heads,
                qkv_bias,
                use_rel_pos,
                window_size,
                vb_b.pp(i),
            )?;
            blocks.push(block)
        }
        let neck_conv1 = candle_nn::conv2d_no_bias(
            embed_dim,
            out_chans,
            1,
            Default::default(),
            vb.pp("neck.0"),
        )?;
        let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?;
        let cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
        let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?;
        let pos_embed = if use_abs_pos {
            let p = vb.get(
                (1, img_size / patch_size, img_size / patch_size, embed_dim),
                "pos_embed",
            )?;
            Some(p)
        } else {
            None
        };
        Ok(Self {
            img_size,
            patch_embed,
            blocks,
            neck_conv1,
            neck_ln1,
            neck_conv2,
            neck_ln2,
            pos_embed,
        })
    }
}

impl Module for ImageEncoderViT {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.patch_embed.forward(xs)?;
        let mut xs = match &self.pos_embed {
            Some(pos_embed) => (xs + pos_embed)?,
            None => xs,
        };
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
        xs.permute((0, 3, 1, 2))?
            .apply(&self.neck_conv1)?
            .apply(&self.neck_ln1)?
            .apply(&self.neck_conv2)?
            .apply(&self.neck_ln2)
    }
}

#[derive(Debug)]
struct MlpMaskDecoder {
    layers: Vec<Linear>,
    sigmoid_output: bool,
}

impl MlpMaskDecoder {
    fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        num_layers: usize,
        sigmoid_output: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        let vb = vb.pp("layers");
        for i in 0..num_layers {
            let in_dim = if i == 0 { input_dim } else { hidden_dim };
            let out_dim = if i + 1 == num_layers {
                output_dim
            } else {
                hidden_dim
            };
            let layer = linear(vb.pp(i), in_dim, out_dim, true)?;
            layers.push(layer)
        }
        Ok(Self {
            layers,
            sigmoid_output,
        })
    }
}

impl Module for MlpMaskDecoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for (i, layer) in self.layers.iter().enumerate() {
            xs = layer.forward(&xs)?;
            if i + 1 < self.layers.len() {
                xs = xs.relu()?
            }
        }
        if self.sigmoid_output {
            candle_nn::ops::sigmoid(&xs)
        } else {
            Ok(xs)
        }
    }
}

#[derive(Debug)]
struct MaskDecoder {
    iou_tokens: candle_nn::Embedding,
    mask_tokens: candle_nn::Embedding,
    iou_prediction_head: MlpMaskDecoder,
}

impl MaskDecoder {
    fn new(
        transformer_dim: usize,
        num_multimask_outputs: usize,
        iou_head_depth: usize,
        iou_head_hidden_dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let num_mask_tokens = num_multimask_outputs - 1;
        let iou_prediction_head = MlpMaskDecoder::new(
            transformer_dim,
            iou_head_hidden_dim,
            num_mask_tokens,
            iou_head_depth,
            false,
            vb.pp("iou_prediction_head"),
        )?;
        let iou_tokens = candle_nn::embedding(1, transformer_dim, vb.pp("iou_tokens"))?;
        let mask_tokens =
            candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
        Ok(Self {
            iou_tokens,
            mask_tokens,
            iou_prediction_head,
        })
    }
}

/*
    fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
        let npatch = xs.dim(1)? - 1;
candle-examples/examples/segment-anything/model_image_encoder.rs (new file, 257 lines)
@@ -0,0 +1,257 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

#[derive(Debug)]
struct PatchEmbed {
    proj: candle_nn::Conv2d,
}

impl PatchEmbed {
    fn new(
        in_chans: usize,
        embed_dim: usize,
        k_size: usize,
        stride: usize,
        padding: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            stride,
            padding,
            ..Default::default()
        };
        let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
        Ok(Self { proj })
    }
}

impl Module for PatchEmbed {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.proj)?.permute((0, 2, 3, 1))
    }
}

#[derive(Debug)]
struct Attention {
    qkv: Linear,
    proj: Linear,
    num_heads: usize,
    scale: f64,
    use_rel_pos: bool,
    rel_pos_hw: Option<(Tensor, Tensor)>,
}

impl Attention {
    fn new(
        dim: usize,
        num_heads: usize,
        qkv_bias: bool,
        use_rel_pos: bool,
        window_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
        let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
        let head_dim = dim / num_heads;
        let scale = 1. / (head_dim as f64).sqrt();
        let rel_pos_hw = if use_rel_pos {
            let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
            let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
            Some((h, w))
        } else {
            None
        };
        Ok(Self {
            qkv,
            proj,
            num_heads,
            scale,
            use_rel_pos,
            rel_pos_hw,
        })
    }
}

impl Module for Attention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b, h, w, c) = xs.dims4()?;
        let qkv = self
            .qkv
            .forward(xs)?
            .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
            .permute((2, 0, 3, 1, 4))?
            .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
        let q = qkv.i(0)?;
        let k = qkv.i(1)?;
        let v = qkv.i(2)?;
        let attn = (q * self.scale)?.matmul(&k.t()?)?;
        if self.use_rel_pos {
            todo!()
        }
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
        let attn = attn
            .matmul(&v)?
            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
            .permute((0, 2, 3, 1, 4))?
            .reshape((b, h, w, c / self.num_heads))?;
        self.proj.forward(&attn)
    }
}

#[derive(Debug)]
struct Block {
    norm1: LayerNorm,
    attn: Attention,
    norm2: LayerNorm,
    mlp: crate::MlpBlock,
    window_size: usize,
}

impl Block {
    fn new(
        dim: usize,
        num_heads: usize,
        qkv_bias: bool,
        use_rel_pos: bool,
        window_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
        let attn = Attention::new(
            dim,
            num_heads,
            qkv_bias,
            use_rel_pos,
            window_size,
            vb.pp("attn"),
        )?;
        let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
        Ok(Self {
            norm1,
            attn,
            norm2,
            mlp,
            window_size,
        })
    }
}

impl Module for Block {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let shortcut = xs;
        let xs = self.norm1.forward(xs)?;
        if self.window_size > 0 {
            todo!()
        }
        let xs = self.attn.forward(&xs)?;
        if self.window_size > 0 {
            todo!()
        }
        let xs = (xs + shortcut)?;
        &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
    }
}

#[derive(Debug)]
struct ImageEncoderViT {
    img_size: usize,
    patch_embed: PatchEmbed,
    blocks: Vec<Block>,
    neck_conv1: candle_nn::Conv2d,
    neck_ln1: LayerNorm,
    neck_conv2: candle_nn::Conv2d,
    neck_ln2: LayerNorm,
    pos_embed: Option<Tensor>,
}

impl ImageEncoderViT {
    #[allow(clippy::too_many_arguments)]
    fn new(
        img_size: usize,
        patch_size: usize,
        in_chans: usize,
        embed_dim: usize,
        depth: usize,
        num_heads: usize,
        out_chans: usize,
        qkv_bias: bool,
        use_rel_pos: bool,
        use_abs_pos: bool,
        window_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let patch_embed = PatchEmbed::new(
            in_chans,
            embed_dim,
            patch_size,
            patch_size,
            0,
            vb.pp("patch_embed"),
        )?;
        let mut blocks = Vec::with_capacity(depth);
        let vb_b = vb.pp("blocks");
        for i in 0..depth {
            let block = Block::new(
                embed_dim,
                num_heads,
                qkv_bias,
                use_rel_pos,
                window_size,
                vb_b.pp(i),
            )?;
            blocks.push(block)
        }
        let neck_conv1 = candle_nn::conv2d_no_bias(
            embed_dim,
            out_chans,
            1,
            Default::default(),
            vb.pp("neck.0"),
        )?;
        let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?;
        let cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
        let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?;
        let pos_embed = if use_abs_pos {
            let p = vb.get(
                (1, img_size / patch_size, img_size / patch_size, embed_dim),
                "pos_embed",
            )?;
            Some(p)
        } else {
            None
        };
        Ok(Self {
            img_size,
            patch_embed,
            blocks,
            neck_conv1,
            neck_ln1,
            neck_conv2,
            neck_ln2,
            pos_embed,
        })
    }
}

impl Module for ImageEncoderViT {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.patch_embed.forward(xs)?;
        let mut xs = match &self.pos_embed {
            Some(pos_embed) => (xs + pos_embed)?,
            None => xs,
        };
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
        xs.permute((0, 3, 1, 2))?
            .apply(&self.neck_conv1)?
            .apply(&self.neck_ln1)?
            .apply(&self.neck_conv2)?
            .apply(&self.neck_ln2)
    }
}
candle-examples/examples/segment-anything/model_mask_decoder.rs (new file, 222 lines)
@@ -0,0 +1,222 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

#[derive(Debug)]
struct MlpMaskDecoder {
    layers: Vec<Linear>,
    sigmoid_output: bool,
}

impl MlpMaskDecoder {
    fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        num_layers: usize,
        sigmoid_output: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        let vb = vb.pp("layers");
        for i in 0..num_layers {
            let in_dim = if i == 0 { input_dim } else { hidden_dim };
            let out_dim = if i + 1 == num_layers {
                output_dim
            } else {
                hidden_dim
            };
            let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
            layers.push(layer)
        }
        Ok(Self {
            layers,
            sigmoid_output,
        })
    }
}

impl Module for MlpMaskDecoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for (i, layer) in self.layers.iter().enumerate() {
            xs = layer.forward(&xs)?;
            if i + 1 < self.layers.len() {
                xs = xs.relu()?
            }
        }
        if self.sigmoid_output {
            candle_nn::ops::sigmoid(&xs)
        } else {
            Ok(xs)
        }
    }
}

#[derive(Debug)]
struct MaskDecoder {
    iou_token: candle_nn::Embedding,
    mask_tokens: candle_nn::Embedding,
    iou_prediction_head: MlpMaskDecoder,
    output_upscaling_conv1: candle_nn::ConvTranspose2d,
    output_upscaling_ln: LayerNorm,
    output_upscaling_conv2: candle_nn::ConvTranspose2d,
    num_mask_tokens: usize,
    output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
}

impl MaskDecoder {
    fn new(
        transformer_dim: usize,
        num_multimask_outputs: usize,
        iou_head_depth: usize,
        iou_head_hidden_dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let num_mask_tokens = num_multimask_outputs - 1;
        let iou_prediction_head = MlpMaskDecoder::new(
            transformer_dim,
            iou_head_hidden_dim,
            num_mask_tokens,
            iou_head_depth,
            false,
            vb.pp("iou_prediction_head"),
        )?;
        let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
        let mask_tokens =
            candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
        let cfg = candle_nn::ConvTranspose2dConfig {
            stride: 2,
            ..Default::default()
        };
        let output_upscaling_conv1 = candle_nn::conv_transpose2d(
            transformer_dim,
            transformer_dim / 4,
            2,
            cfg,
            vb.pp("output_upscaling.0"),
        )?;
        let output_upscaling_ln =
            layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
        let output_upscaling_conv2 = candle_nn::conv_transpose2d(
            transformer_dim / 4,
            transformer_dim / 8,
            2,
            cfg,
            vb.pp("output_upscaling.3"),
        )?;
        let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
        let vb_o = vb.pp("output_hypernetworks_mlps");
        for i in 0..num_mask_tokens {
            let mlp = MlpMaskDecoder::new(
                transformer_dim,
                transformer_dim,
                transformer_dim / 8,
                3,
                false,
                vb_o.pp(i),
            )?;
            output_hypernetworks_mlps.push(mlp)
        }
        Ok(Self {
            iou_token,
            mask_tokens,
            iou_prediction_head,
            output_upscaling_conv1,
            output_upscaling_ln,
            output_upscaling_conv2,
            num_mask_tokens,
            output_hypernetworks_mlps,
        })
    }

    fn forward(
        &self,
        image_embeddings: &Tensor,
        image_pe: &Tensor,
        sparse_prompt_embeddings: &Tensor,
        dense_prompt_embeddings: &Tensor,
        multimask_output: bool,
    ) -> Result<(Tensor, Tensor)> {
        let (masks, iou_pred) = self.predict_masks(
            image_embeddings,
            image_pe,
            sparse_prompt_embeddings,
            dense_prompt_embeddings,
        )?;
        let masks = if multimask_output {
            masks.i((.., 1..))?
        } else {
            masks.i((.., 0..1))?
        };
        let iou_pred = if multimask_output {
            iou_pred.i((.., 1..))?
        } else {
            iou_pred.i((.., 0..1))?
        };
        Ok((masks, iou_pred))
    }

    fn predict_masks(
        &self,
        image_embeddings: &Tensor,
        image_pe: &Tensor,
        sparse_prompt_embeddings: &Tensor,
        dense_prompt_embeddings: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        // Concatenate output tokens.
        let output_tokens = Tensor::cat(
            &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
            0,
        )?;
        let (d1, d2) = output_tokens.dims2()?;
        let output_tokens =
            output_tokens
                .unsqueeze(0)?
                .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
        let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;

        // Expand per-image data in batch direction to be per mask
        let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
        let src = (src + dense_prompt_embeddings)?;
        let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
        let (b, c, h, w) = src.dims4()?;

        // Run the transformer
        let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
        let iou_token_out = hs.i((.., 0))?;
        let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;

        // Upscale mask embeddings and predict masks using the mask tokens.
        let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
        let upscaled_embedding = self
            .output_upscaling_conv1
            .forward(&src)?
            .apply(&self.output_upscaling_ln)?
            .gelu()?
            .apply(&self.output_upscaling_conv2)?
            .gelu()?;
        let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
        for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
            let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
            hyper_in_list.push(h)
        }
        let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
        let (b, c, h, w) = upscaled_embedding.dims4()?;
        let masks = hyper_in
            .matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
            .reshape((b, 0, h, w))?;

        // Generate mask quality predictions.
        let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
        Ok((masks, iou_pred))
    }
}

// Equivalent to torch.repeat_interleave
fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
    todo!()
}

fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
    todo!()
}
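Both helpers at the end of this file are still stubbed with todo!(). For the torch.repeat_interleave equivalent, one possible implementation on top of candle tensor ops is sketched below (unsqueeze, broadcast, then flatten); this is an assumption about how the stub could be filled in, not code from the commit:

// Repeat every slice along `dim` `repeats` times, like torch.repeat_interleave.
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
    let img = img.unsqueeze(dim + 1)?; // insert a new axis right after `dim`
    let mut dims = img.dims().to_vec();
    dims[dim + 1] = repeats; // broadcast that axis to `repeats`
    img.broadcast_as(dims)?
        .contiguous()? // materialize before reshaping
        .flatten(dim, dim + 1) // merge `dim` and the repeat axis
}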
candle-examples/examples/segment-anything/model_transformer.rs (new file, 77 lines)
@@ -0,0 +1,77 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

#[derive(Debug)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    internal_dim: usize,
    num_heads: usize,
}

impl Attention {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        downsample_rate: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let internal_dim = embedding_dim / downsample_rate;
        let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
        let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
        let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
        let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            internal_dim,
            num_heads,
        })
    }

    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n, c) = x.dims3()?;
        x.reshape((b, n, self.num_heads, c / self.num_heads))?
            .transpose(1, 2)
    }

    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
        x.transpose(1, 2)?
            .reshape((b, n_tokens, n_heads * c_per_head))
    }

    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let q = self.q_proj.forward(q)?;
        let k = self.k_proj.forward(k)?;
        let v = self.v_proj.forward(v)?;

        let q = self.separate_heads(&q)?;
        let k = self.separate_heads(&k)?;
        let v = self.separate_heads(&v)?;

        let (_, _, _, c_per_head) = q.dims4()?;
        let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;

        let out = attn.matmul(&v)?;
        self.recombine_heads(&out)?.apply(&self.out_proj)
    }
}

#[derive(Debug)]
struct TwoWayAttentionBlock {
    self_attn: Attention,
    norm1: LayerNorm,
    cross_attn_token_to_image: Attention,
    norm2: LayerNorm,
    mlp: crate::MlpBlock,
    norm3: LayerNorm,
    norm4: LayerNorm,
    cross_attn_image_to_token: Attention,
    skip_first_layer_pe: bool,
}
@@ -130,6 +130,17 @@ pub struct ConvTranspose2dConfig {
    // TODO: support groups.
}

impl Default for ConvTranspose2dConfig {
    fn default() -> Self {
        Self {
            padding: 0,
            output_padding: 0,
            stride: 1,
            dilation: 1,
        }
    }
}

#[derive(Debug)]
pub struct ConvTranspose2d {
    weight: Tensor,
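The Default impl added in this hunk is what lets callers, like the mask decoder above, spell out only the field they change. A small usage sketch based on the mask-decoder hunk:

// Only the stride differs from the defaults; padding, output_padding and dilation
// all come from Default::default().
let cfg = candle_nn::ConvTranspose2dConfig {
    stride: 2,
    ..Default::default()
};
// First upscaling layer of the mask decoder (2x2 kernel, stride 2):
// candle_nn::conv_transpose2d(transformer_dim, transformer_dim / 4, 2, cfg, vb.pp("output_upscaling.0"))?;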