mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Move more models to candle-transformers (#796)
* Move dinov2. * Move efficientnet. * Move the quantized llama model. * Move segment-anything.
This commit is contained in:
@ -7,108 +7,11 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
pub mod model_image_encoder;
|
||||
pub mod model_mask_decoder;
|
||||
pub mod model_prompt_encoder;
|
||||
pub mod model_sam;
|
||||
pub mod model_tiny_vit;
|
||||
pub mod model_transformer;
|
||||
|
||||
use candle::{DType, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle::DType;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::segment_anything::sam;
|
||||
use clap::Parser;
|
||||
|
||||
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
|
||||
let inner = if bias {
|
||||
candle_nn::linear(in_dim, out_dim, vb)?
|
||||
} else {
|
||||
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LayerNorm2d {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
num_channels: usize,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNorm2d {
|
||||
pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let weight = vb.get(num_channels, "weight")?;
|
||||
let bias = vb.get(num_channels, "bias")?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
bias,
|
||||
num_channels,
|
||||
eps,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNorm2d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let u = xs.mean_keepdim(1)?;
|
||||
let xs = xs.broadcast_sub(&u)?;
|
||||
let s = xs.sqr()?.mean_keepdim(1)?;
|
||||
let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
|
||||
xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
|
||||
.broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MlpBlock {
|
||||
lin1: Linear,
|
||||
lin2: Linear,
|
||||
activation: candle_nn::Activation,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl MlpBlock {
|
||||
pub fn new(
|
||||
embedding_dim: usize,
|
||||
mlp_dim: usize,
|
||||
activation: candle_nn::Activation,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
|
||||
let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
|
||||
Ok(Self {
|
||||
lin1,
|
||||
lin2,
|
||||
activation,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MlpBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
xs.apply(&self.lin1)?
|
||||
.apply(&self.activation)?
|
||||
.apply(&self.lin2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
inner: candle_nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Module for Linear {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let (_c, h, w) = image.dims3()?;
|
||||
(image, h, w)
|
||||
} else {
|
||||
let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
|
||||
let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
|
||||
(image.to_device(&device)?, h, w)
|
||||
};
|
||||
println!("loaded image {image:?}");
|
||||
@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let sam = if args.use_tiny {
|
||||
model_sam::Sam::new_tiny(vb)? // tiny vit_t
|
||||
sam::Sam::new_tiny(vb)? // tiny vit_t
|
||||
} else {
|
||||
model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
|
||||
sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
|
||||
};
|
||||
|
||||
if args.generate_masks {
|
||||
|
@ -1,483 +0,0 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PatchEmbed {
|
||||
proj: candle_nn::Conv2d,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl PatchEmbed {
|
||||
fn new(
|
||||
in_chans: usize,
|
||||
embed_dim: usize,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
stride,
|
||||
padding,
|
||||
..Default::default()
|
||||
};
|
||||
let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
|
||||
Ok(Self { proj, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbed {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
xs.apply(&self.proj)?.permute((0, 2, 3, 1))
|
||||
}
|
||||
}
|
||||
|
||||
// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
|
||||
// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096
|
||||
// (attn.reshape((b, q_h, q_w, k_h, k_w))?
|
||||
// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
|
||||
// .reshape((b, q_h * q_w, k_h * k_w))
|
||||
// Ideally we would perform this operation in place but this is not supported in candle at the
|
||||
// moment. We should also investigate using f16 rather than f32.
|
||||
struct Add3(usize, usize, usize, usize, usize);
|
||||
impl candle::CustomOp3 for Add3 {
|
||||
fn name(&self) -> &'static str {
|
||||
"add3"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &candle::CpuStorage,
|
||||
l1: &candle::Layout,
|
||||
s2: &candle::CpuStorage,
|
||||
l2: &candle::Layout,
|
||||
s3: &candle::CpuStorage,
|
||||
l3: &candle::Layout,
|
||||
) -> Result<(candle::CpuStorage, candle::Shape)> {
|
||||
use rayon::prelude::*;
|
||||
|
||||
let Add3(b, q_h, q_w, k_h, k_w) = *self;
|
||||
let s1 = s1.as_slice::<f32>()?;
|
||||
let s1 = match l1.contiguous_offsets() {
|
||||
None => candle::bail!("input1 has to be contiguous"),
|
||||
Some((o1, o2)) => &s1[o1..o2],
|
||||
};
|
||||
let s2 = s2.as_slice::<f32>()?;
|
||||
let s2 = match l2.contiguous_offsets() {
|
||||
None => candle::bail!("input2 has to be contiguous"),
|
||||
Some((o1, o2)) => &s2[o1..o2],
|
||||
};
|
||||
let s3 = s3.as_slice::<f32>()?;
|
||||
let s3 = match l3.contiguous_offsets() {
|
||||
None => candle::bail!("input3 has to be contiguous"),
|
||||
Some((o1, o2)) => &s3[o1..o2],
|
||||
};
|
||||
let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
|
||||
dst.par_chunks_exact_mut(k_h * k_w)
|
||||
.enumerate()
|
||||
.for_each(|(b_idx, dst)| {
|
||||
let s1_idx = b_idx * k_h * k_w;
|
||||
let s2_idx = b_idx * k_h;
|
||||
let s3_idx = b_idx * k_w;
|
||||
for h_idx in 0..k_h {
|
||||
let s1_idx = s1_idx + h_idx * k_w;
|
||||
let s2_idx = s2_idx + h_idx;
|
||||
let dst_idx = h_idx * k_w;
|
||||
for w_idx in 0..k_w {
|
||||
let s1_idx = s1_idx + w_idx;
|
||||
let s3_idx = s3_idx + w_idx;
|
||||
let dst_idx = dst_idx + w_idx;
|
||||
dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
|
||||
}
|
||||
}
|
||||
});
|
||||
let dst = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Attention {
|
||||
qkv: crate::Linear,
|
||||
proj: crate::Linear,
|
||||
num_heads: usize,
|
||||
scale: f64,
|
||||
rel_pos_hw: Option<(Tensor, Tensor)>,
|
||||
span: tracing::Span,
|
||||
span_matmul: tracing::Span,
|
||||
span_rel_pos: tracing::Span,
|
||||
span_softmax: tracing::Span,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
use_rel_pos: bool,
|
||||
input_size: (usize, usize),
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attention");
|
||||
let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
|
||||
let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
|
||||
let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
|
||||
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
|
||||
let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
|
||||
let head_dim = dim / num_heads;
|
||||
let scale = 1. / (head_dim as f64).sqrt();
|
||||
let rel_pos_hw = if use_rel_pos {
|
||||
let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
|
||||
let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
|
||||
Some((h, w))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
qkv,
|
||||
proj,
|
||||
num_heads,
|
||||
scale,
|
||||
rel_pos_hw,
|
||||
span,
|
||||
span_matmul,
|
||||
span_rel_pos,
|
||||
span_softmax,
|
||||
})
|
||||
}
|
||||
|
||||
fn add_decomposed_rel_pos(
|
||||
&self,
|
||||
attn: Tensor,
|
||||
q: &Tensor,
|
||||
(q_h, q_w): (usize, usize),
|
||||
(k_h, k_w): (usize, usize),
|
||||
) -> Result<Tensor> {
|
||||
match &self.rel_pos_hw {
|
||||
Some((rel_pos_h, rel_pos_w)) => {
|
||||
let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
|
||||
let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
|
||||
let (b, _, dim) = q.dims3()?;
|
||||
let r_q = q.reshape((b, q_h, q_w, dim))?;
|
||||
// rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
|
||||
// rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
let rel_w = r_q
|
||||
.transpose(1, 2)? // -> bwhc
|
||||
.contiguous()?
|
||||
.matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
if attn.device().is_cpu() {
|
||||
let op = Add3(b, q_h, q_w, k_h, k_w);
|
||||
attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
|
||||
} else {
|
||||
(attn.reshape((b, q_h, q_w, k_h, k_w))?
|
||||
+ rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
|
||||
.reshape((b, q_h * q_w, k_h * k_w))
|
||||
}
|
||||
}
|
||||
None => Ok(attn),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
|
||||
let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
|
||||
let dev = rel_pos.device();
|
||||
let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
|
||||
todo!("interpolation")
|
||||
} else {
|
||||
rel_pos
|
||||
};
|
||||
let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
|
||||
.reshape((q_size, 1))?
|
||||
.to_dtype(DType::F32)?;
|
||||
let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
|
||||
.reshape((1, k_size))?
|
||||
.to_dtype(DType::F32)?;
|
||||
let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
|
||||
let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
|
||||
let relative_coords = (q_coords.broadcast_sub(&k_coords)?
|
||||
+ (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
|
||||
let (d1, d2) = relative_coords.dims2()?;
|
||||
let relative_coords = relative_coords.to_dtype(DType::U32)?;
|
||||
rel_pos_resized
|
||||
.index_select(&relative_coords.reshape(d1 * d2)?, 0)?
|
||||
.reshape((d1, d2, ()))
|
||||
}
|
||||
|
||||
impl Module for Attention {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b, h, w, c) = xs.dims4()?;
|
||||
let qkv = self
|
||||
.qkv
|
||||
.forward(&xs.flatten_to(1)?)?
|
||||
.reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
|
||||
.permute((2, 0, 3, 1, 4))?
|
||||
.reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
|
||||
let q = qkv.i(0)?;
|
||||
let k = qkv.i(1)?;
|
||||
let v = qkv.i(2)?;
|
||||
let attn = {
|
||||
let _enter = self.span_matmul.enter();
|
||||
(&q * self.scale)?.matmul(&k.t()?)?
|
||||
};
|
||||
let attn = {
|
||||
let _enter = self.span_rel_pos.enter();
|
||||
self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
|
||||
};
|
||||
let attn = {
|
||||
let _enter = self.span_softmax.enter();
|
||||
candle_nn::ops::softmax_last_dim(&attn)?
|
||||
};
|
||||
let attn = {
|
||||
let _enter = self.span_matmul.enter();
|
||||
attn.matmul(&v)?
|
||||
};
|
||||
let attn = attn
|
||||
.reshape((b, self.num_heads, h, w, c / self.num_heads))?
|
||||
.permute((0, 2, 3, 1, 4))?
|
||||
.reshape((b, h * w, c))?;
|
||||
self.proj.forward(&attn)?.reshape((b, h, w, c))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Block {
|
||||
norm1: LayerNorm,
|
||||
attn: Attention,
|
||||
norm2: LayerNorm,
|
||||
mlp: crate::MlpBlock,
|
||||
window_size: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
use_rel_pos: bool,
|
||||
window_size: usize,
|
||||
input_size: (usize, usize),
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
|
||||
let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
|
||||
let input_size_attn = if window_size == 0 {
|
||||
input_size
|
||||
} else {
|
||||
(window_size, window_size)
|
||||
};
|
||||
let attn = Attention::new(
|
||||
dim,
|
||||
num_heads,
|
||||
qkv_bias,
|
||||
use_rel_pos,
|
||||
input_size_attn,
|
||||
vb.pp("attn"),
|
||||
)?;
|
||||
let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "ie-block");
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
norm2,
|
||||
mlp,
|
||||
window_size,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
|
||||
let (b, h, w, c) = xs.dims4()?;
|
||||
let pad_h = (window_size - h % window_size) % window_size;
|
||||
let pad_w = (window_size - w % window_size) % window_size;
|
||||
let xs = if pad_h > 0 {
|
||||
xs.pad_with_zeros(1, 0, pad_h)?
|
||||
} else {
|
||||
xs
|
||||
};
|
||||
let xs = if pad_w > 0 {
|
||||
xs.pad_with_zeros(2, 0, pad_w)?
|
||||
} else {
|
||||
xs
|
||||
};
|
||||
let (h_p, w_p) = (h + pad_h, w + pad_w);
|
||||
let windows = xs
|
||||
.reshape((
|
||||
b,
|
||||
h_p / window_size,
|
||||
window_size,
|
||||
w_p / window_size,
|
||||
window_size,
|
||||
c,
|
||||
))?
|
||||
.transpose(2, 3)?
|
||||
.contiguous()?
|
||||
.flatten_to(2)?;
|
||||
Ok((windows, (h_p, w_p)))
|
||||
}
|
||||
|
||||
fn window_unpartition(
|
||||
windows: Tensor,
|
||||
window_size: usize,
|
||||
(h_p, w_p): (usize, usize),
|
||||
(h, w): (usize, usize),
|
||||
) -> Result<Tensor> {
|
||||
let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
|
||||
let xs = windows
|
||||
.reshape((
|
||||
b,
|
||||
h_p / window_size,
|
||||
w_p / window_size,
|
||||
window_size,
|
||||
window_size,
|
||||
windows.elem_count() / b / h_p / w_p,
|
||||
))?
|
||||
.transpose(2, 3)?
|
||||
.contiguous()?
|
||||
.reshape((b, h_p, w_p, ()))?;
|
||||
let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
|
||||
let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
impl Module for Block {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let shortcut = xs;
|
||||
let xs = self.norm1.forward(xs)?;
|
||||
let hw = (xs.dim(1)?, xs.dim(2)?);
|
||||
let (xs, pad_hw) = if self.window_size > 0 {
|
||||
window_partition(xs, self.window_size)?
|
||||
} else {
|
||||
(xs, (0, 0))
|
||||
};
|
||||
let xs = self.attn.forward(&xs)?;
|
||||
let xs = if self.window_size > 0 {
|
||||
window_unpartition(xs, self.window_size, pad_hw, hw)?
|
||||
} else {
|
||||
xs
|
||||
};
|
||||
let xs = (xs + shortcut)?;
|
||||
&xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ImageEncoderViT {
|
||||
patch_embed: PatchEmbed,
|
||||
blocks: Vec<Block>,
|
||||
neck_conv1: candle_nn::Conv2d,
|
||||
neck_ln1: crate::LayerNorm2d,
|
||||
neck_conv2: candle_nn::Conv2d,
|
||||
neck_ln2: crate::LayerNorm2d,
|
||||
pos_embed: Option<Tensor>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ImageEncoderViT {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
img_size: usize,
|
||||
patch_size: usize,
|
||||
in_chans: usize,
|
||||
embed_dim: usize,
|
||||
depth: usize,
|
||||
num_heads: usize,
|
||||
out_chans: usize,
|
||||
qkv_bias: bool,
|
||||
use_rel_pos: bool,
|
||||
use_abs_pos: bool,
|
||||
window_size: usize,
|
||||
global_attn_indexes: &[usize],
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let patch_embed = PatchEmbed::new(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
patch_size,
|
||||
patch_size,
|
||||
0,
|
||||
vb.pp("patch_embed"),
|
||||
)?;
|
||||
let mut blocks = Vec::with_capacity(depth);
|
||||
let vb_b = vb.pp("blocks");
|
||||
for i in 0..depth {
|
||||
let window_size = if global_attn_indexes.contains(&i) {
|
||||
0
|
||||
} else {
|
||||
window_size
|
||||
};
|
||||
let block = Block::new(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
qkv_bias,
|
||||
use_rel_pos,
|
||||
window_size,
|
||||
(img_size / patch_size, img_size / patch_size),
|
||||
vb_b.pp(i),
|
||||
)?;
|
||||
blocks.push(block)
|
||||
}
|
||||
let neck_conv1 = candle_nn::conv2d_no_bias(
|
||||
embed_dim,
|
||||
out_chans,
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("neck.0"),
|
||||
)?;
|
||||
let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
|
||||
let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
|
||||
let pos_embed = if use_abs_pos {
|
||||
let p = vb.get(
|
||||
(1, img_size / patch_size, img_size / patch_size, embed_dim),
|
||||
"pos_embed",
|
||||
)?;
|
||||
Some(p)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
blocks,
|
||||
neck_conv1,
|
||||
neck_ln1,
|
||||
neck_conv2,
|
||||
neck_ln2,
|
||||
pos_embed,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ImageEncoderViT {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = self.patch_embed.forward(xs)?;
|
||||
let mut xs = match &self.pos_embed {
|
||||
Some(pos_embed) => (xs + pos_embed)?,
|
||||
None => xs,
|
||||
};
|
||||
for block in self.blocks.iter() {
|
||||
xs = block.forward(&xs)?
|
||||
}
|
||||
xs.permute((0, 3, 1, 2))?
|
||||
.apply(&self.neck_conv1)?
|
||||
.apply(&self.neck_ln1)?
|
||||
.apply(&self.neck_conv2)?
|
||||
.apply(&self.neck_ln2)
|
||||
}
|
||||
}
|
@ -1,239 +0,0 @@
|
||||
use candle::{IndexOp, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
|
||||
use crate::model_transformer::TwoWayTransformer;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MlpMaskDecoder {
|
||||
layers: Vec<crate::Linear>,
|
||||
sigmoid_output: bool,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl MlpMaskDecoder {
|
||||
fn new(
|
||||
input_dim: usize,
|
||||
hidden_dim: usize,
|
||||
output_dim: usize,
|
||||
num_layers: usize,
|
||||
sigmoid_output: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
let vb = vb.pp("layers");
|
||||
for i in 0..num_layers {
|
||||
let in_dim = if i == 0 { input_dim } else { hidden_dim };
|
||||
let out_dim = if i + 1 == num_layers {
|
||||
output_dim
|
||||
} else {
|
||||
hidden_dim
|
||||
};
|
||||
let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
|
||||
Ok(Self {
|
||||
layers,
|
||||
sigmoid_output,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MlpMaskDecoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
xs = layer.forward(&xs)?;
|
||||
if i + 1 < self.layers.len() {
|
||||
xs = xs.relu()?
|
||||
}
|
||||
}
|
||||
if self.sigmoid_output {
|
||||
candle_nn::ops::sigmoid(&xs)
|
||||
} else {
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MaskDecoder {
|
||||
iou_token: candle_nn::Embedding,
|
||||
mask_tokens: candle_nn::Embedding,
|
||||
iou_prediction_head: MlpMaskDecoder,
|
||||
output_upscaling_conv1: candle_nn::ConvTranspose2d,
|
||||
output_upscaling_ln: crate::LayerNorm2d,
|
||||
output_upscaling_conv2: candle_nn::ConvTranspose2d,
|
||||
num_mask_tokens: usize,
|
||||
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
|
||||
transformer: TwoWayTransformer,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl MaskDecoder {
|
||||
pub fn new(
|
||||
transformer_dim: usize,
|
||||
num_multimask_outputs: usize,
|
||||
iou_head_depth: usize,
|
||||
iou_head_hidden_dim: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let num_mask_tokens = num_multimask_outputs + 1;
|
||||
let iou_prediction_head = MlpMaskDecoder::new(
|
||||
transformer_dim,
|
||||
iou_head_hidden_dim,
|
||||
num_mask_tokens,
|
||||
iou_head_depth,
|
||||
false,
|
||||
vb.pp("iou_prediction_head"),
|
||||
)?;
|
||||
let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
|
||||
let mask_tokens =
|
||||
candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
|
||||
let cfg = candle_nn::ConvTranspose2dConfig {
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let output_upscaling_conv1 = candle_nn::conv_transpose2d(
|
||||
transformer_dim,
|
||||
transformer_dim / 4,
|
||||
2,
|
||||
cfg,
|
||||
vb.pp("output_upscaling.0"),
|
||||
)?;
|
||||
let output_upscaling_ln =
|
||||
crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
|
||||
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
|
||||
transformer_dim / 4,
|
||||
transformer_dim / 8,
|
||||
2,
|
||||
cfg,
|
||||
vb.pp("output_upscaling.3"),
|
||||
)?;
|
||||
let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
|
||||
let vb_o = vb.pp("output_hypernetworks_mlps");
|
||||
for i in 0..num_mask_tokens {
|
||||
let mlp = MlpMaskDecoder::new(
|
||||
transformer_dim,
|
||||
transformer_dim,
|
||||
transformer_dim / 8,
|
||||
3,
|
||||
false,
|
||||
vb_o.pp(i),
|
||||
)?;
|
||||
output_hypernetworks_mlps.push(mlp)
|
||||
}
|
||||
let transformer = TwoWayTransformer::new(
|
||||
/* depth */ 2,
|
||||
/* embedding_dim */ transformer_dim,
|
||||
/* num_heads */ 8,
|
||||
/* mlp_dim */ 2048,
|
||||
vb.pp("transformer"),
|
||||
)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
|
||||
Ok(Self {
|
||||
iou_token,
|
||||
mask_tokens,
|
||||
iou_prediction_head,
|
||||
output_upscaling_conv1,
|
||||
output_upscaling_ln,
|
||||
output_upscaling_conv2,
|
||||
num_mask_tokens,
|
||||
output_hypernetworks_mlps,
|
||||
transformer,
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
image_embeddings: &Tensor,
|
||||
image_pe: &Tensor,
|
||||
sparse_prompt_embeddings: &Tensor,
|
||||
dense_prompt_embeddings: &Tensor,
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let _enter = self.span.enter();
|
||||
let (masks, iou_pred) = self.predict_masks(
|
||||
image_embeddings,
|
||||
image_pe,
|
||||
sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings,
|
||||
)?;
|
||||
let masks = if multimask_output {
|
||||
masks.i((.., 1..))?
|
||||
} else {
|
||||
masks.i((.., 0..1))?
|
||||
};
|
||||
let iou_pred = if multimask_output {
|
||||
iou_pred.i((.., 1..))?
|
||||
} else {
|
||||
iou_pred.i((.., 0..1))?
|
||||
};
|
||||
Ok((masks, iou_pred))
|
||||
}
|
||||
|
||||
fn predict_masks(
|
||||
&self,
|
||||
image_embeddings: &Tensor,
|
||||
image_pe: &Tensor,
|
||||
sparse_prompt_embeddings: &Tensor,
|
||||
dense_prompt_embeddings: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
// Concatenate ouput tokens.
|
||||
let output_tokens = Tensor::cat(
|
||||
&[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
|
||||
0,
|
||||
)?;
|
||||
let (d1, d2) = output_tokens.dims2()?;
|
||||
let output_tokens =
|
||||
output_tokens
|
||||
.unsqueeze(0)?
|
||||
.expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
|
||||
let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
|
||||
|
||||
// Expand per-image data in batch direction to be per mask
|
||||
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
|
||||
let src = src.broadcast_add(dense_prompt_embeddings)?;
|
||||
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
|
||||
let (b, c, h, w) = src.dims4()?;
|
||||
|
||||
// Run the transformer
|
||||
let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
|
||||
let iou_token_out = hs.i((.., 0))?;
|
||||
let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
|
||||
|
||||
// Upscale mask embeddings and predict masks using the masks tokens.
|
||||
let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
|
||||
let upscaled_embedding = self
|
||||
.output_upscaling_conv1
|
||||
.forward(&src)?
|
||||
.apply(&self.output_upscaling_ln)?
|
||||
.gelu()?
|
||||
.apply(&self.output_upscaling_conv2)?
|
||||
.gelu()?;
|
||||
let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
|
||||
for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
|
||||
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
|
||||
hyper_in_list.push(h)
|
||||
}
|
||||
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
|
||||
let (b, c, h, w) = upscaled_embedding.dims4()?;
|
||||
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
|
||||
let masks = masks.reshape((b, (), h, w))?;
|
||||
|
||||
// Generate mask quality predictions.
|
||||
let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
|
||||
Ok((masks, iou_pred))
|
||||
}
|
||||
}
|
||||
|
||||
// Equivalent to torch.repeat_interleave
|
||||
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
|
||||
let img = img.unsqueeze(dim + 1)?;
|
||||
let mut dims = img.dims().to_vec();
|
||||
dims[dim + 1] = repeats;
|
||||
img.broadcast_as(dims)?.flatten(dim, dim + 1)
|
||||
}
|
@ -1,239 +0,0 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PostionEmbeddingRandom {
|
||||
positional_encoding_gaussian_matrix: Tensor,
|
||||
}
|
||||
|
||||
impl PostionEmbeddingRandom {
|
||||
fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let positional_encoding_gaussian_matrix =
|
||||
vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
|
||||
Ok(Self {
|
||||
positional_encoding_gaussian_matrix,
|
||||
})
|
||||
}
|
||||
|
||||
fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
|
||||
let coords = coords.affine(2., -1.)?;
|
||||
let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
|
||||
let coords = (coords * (2. * std::f64::consts::PI))?;
|
||||
Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
|
||||
}
|
||||
|
||||
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
|
||||
let device = self.positional_encoding_gaussian_matrix.device();
|
||||
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let x_embed = (x_embed / w as f64)?
|
||||
.reshape((1, ()))?
|
||||
.broadcast_as((h, w))?;
|
||||
let y_embed = (y_embed / h as f64)?
|
||||
.reshape(((), 1))?
|
||||
.broadcast_as((h, w))?;
|
||||
let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
|
||||
self.pe_encoding(&coords)?.permute((2, 0, 1))
|
||||
}
|
||||
|
||||
fn forward_with_coords(
|
||||
&self,
|
||||
coords_input: &Tensor,
|
||||
image_size: (usize, usize),
|
||||
) -> Result<Tensor> {
|
||||
let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
|
||||
let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
|
||||
let c = coords_input.dim(D::Minus1)?;
|
||||
let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
|
||||
let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
|
||||
self.pe_encoding(&coords)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PromptEncoder {
|
||||
pe_layer: PostionEmbeddingRandom,
|
||||
point_embeddings: Vec<candle_nn::Embedding>,
|
||||
not_a_point_embed: candle_nn::Embedding,
|
||||
mask_downscaling_conv1: candle_nn::Conv2d,
|
||||
mask_downscaling_ln1: crate::LayerNorm2d,
|
||||
mask_downscaling_conv2: candle_nn::Conv2d,
|
||||
mask_downscaling_ln2: crate::LayerNorm2d,
|
||||
mask_downscaling_conv3: candle_nn::Conv2d,
|
||||
no_mask_embed: candle_nn::Embedding,
|
||||
image_embedding_size: (usize, usize),
|
||||
input_image_size: (usize, usize),
|
||||
embed_dim: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl PromptEncoder {
|
||||
pub fn new(
|
||||
embed_dim: usize,
|
||||
image_embedding_size: (usize, usize),
|
||||
input_image_size: (usize, usize),
|
||||
mask_in_chans: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let num_points_embeddings = 4;
|
||||
let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
|
||||
let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
|
||||
let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let mask_downscaling_conv1 =
|
||||
candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
|
||||
let mask_downscaling_conv2 = candle_nn::conv2d(
|
||||
mask_in_chans / 4,
|
||||
mask_in_chans,
|
||||
2,
|
||||
cfg,
|
||||
vb.pp("mask_downscaling.3"),
|
||||
)?;
|
||||
let mask_downscaling_conv3 = candle_nn::conv2d(
|
||||
mask_in_chans,
|
||||
embed_dim,
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("mask_downscaling.6"),
|
||||
)?;
|
||||
let mask_downscaling_ln1 =
|
||||
crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
|
||||
let mask_downscaling_ln2 =
|
||||
crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
|
||||
let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
|
||||
let vb_e = vb.pp("point_embeddings");
|
||||
for i in 0..num_points_embeddings {
|
||||
let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
|
||||
point_embeddings.push(emb)
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
|
||||
Ok(Self {
|
||||
pe_layer,
|
||||
point_embeddings,
|
||||
not_a_point_embed,
|
||||
mask_downscaling_conv1,
|
||||
mask_downscaling_ln1,
|
||||
mask_downscaling_conv2,
|
||||
mask_downscaling_ln2,
|
||||
mask_downscaling_conv3,
|
||||
no_mask_embed,
|
||||
image_embedding_size,
|
||||
input_image_size,
|
||||
embed_dim,
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_dense_pe(&self) -> Result<Tensor> {
|
||||
self.pe_layer
|
||||
.forward(self.image_embedding_size.0, self.image_embedding_size.1)?
|
||||
.unsqueeze(0)
|
||||
}
|
||||
|
||||
fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
|
||||
masks
|
||||
.apply(&self.mask_downscaling_conv1)?
|
||||
.apply(&self.mask_downscaling_ln1)?
|
||||
.gelu()?
|
||||
.apply(&self.mask_downscaling_conv2)?
|
||||
.apply(&self.mask_downscaling_ln2)?
|
||||
.gelu()?
|
||||
.apply(&self.mask_downscaling_conv3)
|
||||
}
|
||||
|
||||
fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
|
||||
let points = (points + 0.5)?;
|
||||
let dev = points.device();
|
||||
let (points, labels) = if pad {
|
||||
let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
|
||||
let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
|
||||
let points = Tensor::cat(&[&points, &padding_point], 1)?;
|
||||
let labels = Tensor::cat(&[labels, &padding_label], 1)?;
|
||||
(points, labels)
|
||||
} else {
|
||||
(points, labels.clone())
|
||||
};
|
||||
let point_embedding = self
|
||||
.pe_layer
|
||||
.forward_with_coords(&points, self.input_image_size)?;
|
||||
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
|
||||
let zeros = point_embedding.zeros_like()?;
|
||||
let point_embedding = labels.lt(0f32)?.where_cond(
|
||||
&self
|
||||
.not_a_point_embed
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&point_embedding,
|
||||
)?;
|
||||
let labels0 = labels.eq(0f32)?.where_cond(
|
||||
&self.point_embeddings[0]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels0)?;
|
||||
let labels1 = labels.eq(1f32)?.where_cond(
|
||||
&self.point_embeddings[1]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels1)?;
|
||||
Ok(point_embedding)
|
||||
}
|
||||
|
||||
fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
|
||||
let boxes = (boxes + 0.5)?;
|
||||
let coords = boxes.reshape(((), 2, 2))?;
|
||||
let corner_embedding = self
|
||||
.pe_layer
|
||||
.forward_with_coords(&coords, self.input_image_size)?;
|
||||
let ce1 = corner_embedding.i((.., 0))?;
|
||||
let ce2 = corner_embedding.i((.., 1))?;
|
||||
let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
|
||||
let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
|
||||
Tensor::cat(&[&ce1, &ce2], 1)
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
points: Option<(&Tensor, &Tensor)>,
|
||||
boxes: Option<&Tensor>,
|
||||
masks: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let _enter = self.span.enter();
|
||||
let se_points = match points {
|
||||
Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
|
||||
None => None,
|
||||
};
|
||||
let se_boxes = match boxes {
|
||||
Some(boxes) => Some(self.embed_boxes(boxes)?),
|
||||
None => None,
|
||||
};
|
||||
let sparse_embeddings = match (se_points, se_boxes) {
|
||||
(Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
|
||||
(Some(se_points), None) => se_points,
|
||||
(None, Some(se_boxes)) => se_boxes,
|
||||
(None, None) => {
|
||||
Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
|
||||
}
|
||||
};
|
||||
|
||||
let dense_embeddings = match masks {
|
||||
None => {
|
||||
let emb = self.no_mask_embed.embeddings();
|
||||
emb.reshape((1, (), 1, 1))?.expand((
|
||||
1,
|
||||
emb.elem_count(),
|
||||
self.image_embedding_size.0,
|
||||
self.image_embedding_size.1,
|
||||
))?
|
||||
}
|
||||
Some(masks) => self.embed_masks(masks)?,
|
||||
};
|
||||
Ok((sparse_embeddings, dense_embeddings))
|
||||
}
|
||||
}
|
@ -1,411 +0,0 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
|
||||
use crate::model_image_encoder::ImageEncoderViT;
|
||||
use crate::model_mask_decoder::MaskDecoder;
|
||||
use crate::model_prompt_encoder::PromptEncoder;
|
||||
use crate::model_tiny_vit::{tiny_vit_5m, TinyViT};
|
||||
|
||||
const PROMPT_EMBED_DIM: usize = 256;
|
||||
pub const IMAGE_SIZE: usize = 1024;
|
||||
const VIT_PATCH_SIZE: usize = 16;
|
||||
const PRED_IOU_THRESH: f32 = 0.88;
|
||||
const STABILITY_SCORE_OFFSET: f32 = 1.0;
|
||||
const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
|
||||
const MODEL_MASK_THRESHOLD: f32 = 0.0;
|
||||
const CROP_NMS_THRESH: f32 = 0.7;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ImageEncoder {
|
||||
Original(ImageEncoderViT),
|
||||
TinyViT(TinyViT),
|
||||
}
|
||||
|
||||
impl Module for ImageEncoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::Original(vit) => vit.forward(xs),
|
||||
Self::TinyViT(vit) => vit.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Sam {
|
||||
image_encoder: ImageEncoder,
|
||||
prompt_encoder: PromptEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
pixel_mean: Tensor,
|
||||
pixel_std: Tensor,
|
||||
}
|
||||
|
||||
impl Sam {
|
||||
pub fn new(
|
||||
encoder_embed_dim: usize,
|
||||
encoder_depth: usize,
|
||||
encoder_num_heads: usize,
|
||||
encoder_global_attn_indexes: &[usize],
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
|
||||
|
||||
let image_encoder = ImageEncoderViT::new(
|
||||
IMAGE_SIZE,
|
||||
VIT_PATCH_SIZE,
|
||||
3,
|
||||
encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
PROMPT_EMBED_DIM,
|
||||
/* qkv_bias */ true,
|
||||
/* use_rel_pos */ true,
|
||||
/* use_abs_pos */ true,
|
||||
/* window_size */ 14,
|
||||
/* global_attn_indexes */ encoder_global_attn_indexes,
|
||||
vb.pp("image_encoder"),
|
||||
)?;
|
||||
let prompt_encoder = PromptEncoder::new(
|
||||
PROMPT_EMBED_DIM,
|
||||
(image_embedding_size, image_embedding_size),
|
||||
(IMAGE_SIZE, IMAGE_SIZE),
|
||||
16,
|
||||
vb.pp("prompt_encoder"),
|
||||
)?;
|
||||
let mask_decoder = MaskDecoder::new(
|
||||
PROMPT_EMBED_DIM,
|
||||
/* num_multitask_outputs */ 3,
|
||||
/* iou_head_depth */ 3,
|
||||
/* iou_head_hidden_dim */ 256,
|
||||
vb.pp("mask_decoder"),
|
||||
)?;
|
||||
let pixel_mean =
|
||||
Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
|
||||
let pixel_std =
|
||||
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
||||
Ok(Self {
|
||||
image_encoder: ImageEncoder::Original(image_encoder),
|
||||
prompt_encoder,
|
||||
mask_decoder,
|
||||
pixel_std,
|
||||
pixel_mean,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
|
||||
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
|
||||
|
||||
let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
|
||||
let prompt_encoder = PromptEncoder::new(
|
||||
PROMPT_EMBED_DIM,
|
||||
(image_embedding_size, image_embedding_size),
|
||||
(IMAGE_SIZE, IMAGE_SIZE),
|
||||
16,
|
||||
vb.pp("prompt_encoder"),
|
||||
)?;
|
||||
let mask_decoder = MaskDecoder::new(
|
||||
PROMPT_EMBED_DIM,
|
||||
/* num_multitask_outputs */ 3,
|
||||
/* iou_head_depth */ 3,
|
||||
/* iou_head_hidden_dim */ 256,
|
||||
vb.pp("mask_decoder"),
|
||||
)?;
|
||||
let pixel_mean =
|
||||
Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
|
||||
let pixel_std =
|
||||
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
||||
Ok(Self {
|
||||
image_encoder: ImageEncoder::TinyViT(image_encoder),
|
||||
prompt_encoder,
|
||||
mask_decoder,
|
||||
pixel_std,
|
||||
pixel_mean,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
point: Option<(f64, f64)>,
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_c, original_h, original_w) = img.dims3()?;
|
||||
let img = self.preprocess(img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let points = match point {
|
||||
None => None,
|
||||
Some((x, y)) => {
|
||||
let points = Tensor::new(
|
||||
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
|
||||
img.device(),
|
||||
)?;
|
||||
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
|
||||
Some((points, labels))
|
||||
}
|
||||
};
|
||||
let points = points.as_ref().map(|(x, y)| (x, y));
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder.forward(points, None, None)?;
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
multimask_output,
|
||||
)?;
|
||||
let mask = low_res_mask
|
||||
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
|
||||
.get(0)?
|
||||
.i((.., ..original_h, ..original_w))?;
|
||||
Ok((mask, iou_predictions))
|
||||
}
|
||||
|
||||
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||
let img = img
|
||||
.broadcast_mul(&self.pixel_std)?
|
||||
.broadcast_add(&self.pixel_mean)?;
|
||||
img.maximum(&img.zeros_like()?)?
|
||||
.minimum(&(img.ones_like()? * 255.)?)
|
||||
}
|
||||
|
||||
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||
let (_c, h, w) = img.dims3()?;
|
||||
let img = img
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_sub(&self.pixel_mean)?
|
||||
.broadcast_div(&self.pixel_std)?;
|
||||
if h > IMAGE_SIZE || w > IMAGE_SIZE {
|
||||
candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
|
||||
}
|
||||
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
|
||||
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
|
||||
}
|
||||
|
||||
fn process_crop(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
cb: CropBox,
|
||||
point_grids: &[(f64, f64)],
|
||||
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
|
||||
// Crop the image and calculate embeddings.
|
||||
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
|
||||
let img = self.preprocess(&img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
|
||||
let crop_w = cb.x1 - cb.x0;
|
||||
let crop_h = cb.y1 - cb.y0;
|
||||
|
||||
// Generate masks for this crop.
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let points = point_grids
|
||||
.iter()
|
||||
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut bboxes = Vec::new();
|
||||
for points in points.chunks(64) {
|
||||
// Run the model on this batch.
|
||||
let points_len = points.len();
|
||||
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
|
||||
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder
|
||||
.forward(Some((&in_points, &in_labels)), None, None)?;
|
||||
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
/* multimask_output */ true,
|
||||
)?;
|
||||
let low_res_mask = low_res_mask.flatten(0, 1)?;
|
||||
let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
|
||||
let dev = low_res_mask.device();
|
||||
|
||||
for (i, iou) in iou_predictions.iter().enumerate() {
|
||||
// Filter by predicted IoU.
|
||||
if *iou < PRED_IOU_THRESH {
|
||||
continue;
|
||||
}
|
||||
let low_res_mask = low_res_mask.get(i)?;
|
||||
|
||||
// Calculate stability score.
|
||||
let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
|
||||
.broadcast_as(low_res_mask.shape())?;
|
||||
let intersections = low_res_mask
|
||||
.ge(&bound)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
|
||||
.broadcast_as(low_res_mask.shape())?;
|
||||
let unions = low_res_mask
|
||||
.ge(&bound)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
let stability_score = intersections / unions;
|
||||
if stability_score < STABILITY_SCORE_THRESHOLD {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Threshold masks and calculate boxes.
|
||||
let low_res_mask = low_res_mask
|
||||
.ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
|
||||
.to_dtype(DType::U32)?;
|
||||
let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
|
||||
let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
|
||||
let min_max_x = min_max_indexes(&low_res_mask_per_x);
|
||||
let min_max_y = min_max_indexes(&low_res_mask_per_y);
|
||||
if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
|
||||
let bbox = candle_examples::object_detection::Bbox {
|
||||
xmin: x0 as f32,
|
||||
ymin: y0 as f32,
|
||||
xmax: x1 as f32,
|
||||
ymax: y1 as f32,
|
||||
confidence: *iou,
|
||||
data: low_res_mask,
|
||||
};
|
||||
bboxes.push(bbox);
|
||||
}
|
||||
// TODO:
|
||||
// Filter boxes that touch crop boundaries
|
||||
// Compress to RLE.
|
||||
}
|
||||
}
|
||||
|
||||
let mut bboxes = vec![bboxes];
|
||||
// Remove duplicates within this crop.
|
||||
candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
|
||||
|
||||
// TODO: Return to the original image frame.
|
||||
Ok(bboxes.remove(0))
|
||||
}
|
||||
|
||||
pub fn generate_masks(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
points_per_side: usize,
|
||||
crop_n_layer: usize,
|
||||
crop_overlap_ratio: f64,
|
||||
crop_n_points_downscale_factor: usize,
|
||||
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
|
||||
let (_c, h, w) = img.dims3()?;
|
||||
let point_grids = build_all_layer_point_grids(
|
||||
points_per_side,
|
||||
crop_n_layer,
|
||||
crop_n_points_downscale_factor,
|
||||
);
|
||||
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
|
||||
let mut bboxes = Vec::new();
|
||||
for crop_box in crop_boxes.into_iter() {
|
||||
let layer_idx = crop_box.layer_idx;
|
||||
let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
|
||||
bboxes.extend(b)
|
||||
}
|
||||
// TODO: remove duplicates
|
||||
Ok(bboxes)
|
||||
}
|
||||
}
|
||||
|
||||
// Return the first and last indexes i for which values[i] > 0
|
||||
fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
|
||||
let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
|
||||
for (i, &s) in values.iter().enumerate() {
|
||||
if s == 0 {
|
||||
continue;
|
||||
}
|
||||
min_i = usize::min(i, min_i);
|
||||
max_i = usize::max(i, max_i);
|
||||
}
|
||||
if max_i < min_i {
|
||||
None
|
||||
} else {
|
||||
Some((min_i, max_i))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CropBox {
|
||||
x0: usize,
|
||||
y0: usize,
|
||||
x1: usize,
|
||||
y1: usize,
|
||||
layer_idx: usize,
|
||||
}
|
||||
|
||||
impl CropBox {
|
||||
fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
|
||||
Self {
|
||||
x0,
|
||||
y0,
|
||||
x1,
|
||||
y1,
|
||||
layer_idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_crop_boxes(
|
||||
(im_h, im_w): (usize, usize),
|
||||
n_layers: usize,
|
||||
overlap_ratio: f64,
|
||||
) -> Vec<CropBox> {
|
||||
fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
|
||||
f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
|
||||
}
|
||||
|
||||
let short_side = usize::min(im_h, im_w);
|
||||
|
||||
let mut crop_boxes = Vec::new();
|
||||
|
||||
// Original image.
|
||||
crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
|
||||
|
||||
for layer_idx in 1..=n_layers {
|
||||
let n_crops_per_side = 1 << layer_idx;
|
||||
let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
|
||||
let crop_w = crop_len(im_w, n_crops_per_side, overlap);
|
||||
let crop_h = crop_len(im_w, n_crops_per_side, overlap);
|
||||
|
||||
for i_x in 0..n_crops_per_side {
|
||||
let x0 = (crop_w - overlap) * i_x;
|
||||
for i_y in 0..n_crops_per_side {
|
||||
let y0 = (crop_h - overlap) * i_y;
|
||||
let x1 = usize::min(im_w, x0 + crop_w);
|
||||
let y1 = usize::min(im_h, y0 + crop_h);
|
||||
crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crop_boxes
|
||||
}
|
||||
|
||||
// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
|
||||
fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
|
||||
let offset = 1f64 / (2 * n_per_side) as f64;
|
||||
let mut points = Vec::with_capacity(n_per_side * n_per_side);
|
||||
for i_x in 0..n_per_side {
|
||||
let x = offset + i_x as f64 / n_per_side as f64;
|
||||
for i_y in 0..n_per_side {
|
||||
let y = offset + i_y as f64 / n_per_side as f64;
|
||||
points.push((x, y))
|
||||
}
|
||||
}
|
||||
points
|
||||
}
|
||||
|
||||
fn build_all_layer_point_grids(
|
||||
n_per_side: usize,
|
||||
n_layers: usize,
|
||||
scale_per_layer: usize,
|
||||
) -> Vec<Vec<(f64, f64)>> {
|
||||
let mut points_by_layer = Vec::with_capacity(n_layers + 1);
|
||||
for i in 0..=n_layers {
|
||||
let n_points = n_per_side / scale_per_layer.pow(i as u32);
|
||||
points_by_layer.push(build_point_grid(n_points))
|
||||
}
|
||||
points_by_layer
|
||||
}
|
@ -1,633 +0,0 @@
|
||||
// Adapted from:
|
||||
// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Conv2dConfig, Module, VarBuilder};
|
||||
|
||||
const MBCONV_EXPAND_RATIO: usize = 4;
|
||||
const MLP_RATIO: usize = 4;
|
||||
const LOCAL_CONV_SIZE: usize = 3;
|
||||
const IMG_SIZE: usize = 1024;
|
||||
const IN_CHANNELS: usize = 3;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Conv2dBN {
|
||||
c: candle_nn::Conv2d,
|
||||
bn: candle_nn::BatchNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Conv2dBN {
|
||||
fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
|
||||
let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
|
||||
Ok(Self { c, bn, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Conv2dBN {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
xs.apply(&self.c)?.apply(&self.bn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PatchEmbed {
|
||||
conv1: Conv2dBN,
|
||||
conv2: Conv2dBN,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl PatchEmbed {
|
||||
fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
stride: 2,
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
|
||||
let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
|
||||
Ok(Self { conv1, conv2, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbed {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MBConv {
|
||||
conv1: Conv2dBN,
|
||||
conv2: Conv2dBN,
|
||||
conv3: Conv2dBN,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl MBConv {
|
||||
fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden = in_ * expand_ratio;
|
||||
let cfg2 = candle_nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
groups: hidden,
|
||||
..Default::default()
|
||||
};
|
||||
let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
|
||||
let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
|
||||
let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
|
||||
Ok(Self {
|
||||
conv1,
|
||||
conv2,
|
||||
conv3,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MBConv {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let shortcut = xs;
|
||||
let xs = xs
|
||||
.apply(&self.conv1)?
|
||||
.gelu()?
|
||||
.apply(&self.conv2)?
|
||||
.gelu()?
|
||||
.apply(&self.conv3)?;
|
||||
(xs + shortcut)?.gelu()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PatchMerging {
|
||||
conv1: Conv2dBN,
|
||||
conv2: Conv2dBN,
|
||||
conv3: Conv2dBN,
|
||||
input_resolution: (usize, usize),
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl PatchMerging {
|
||||
fn new(
|
||||
input_resolution: (usize, usize),
|
||||
dim: usize,
|
||||
out: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
|
||||
let cfg2 = candle_nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
stride,
|
||||
groups: out,
|
||||
..Default::default()
|
||||
};
|
||||
let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
|
||||
let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
|
||||
let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
|
||||
Ok(Self {
|
||||
conv1,
|
||||
conv2,
|
||||
conv3,
|
||||
input_resolution,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchMerging {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = if xs.rank() == 3 {
|
||||
let (h, w) = self.input_resolution;
|
||||
let b = xs.dim(0)?;
|
||||
xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
xs.apply(&self.conv1)?
|
||||
.gelu()?
|
||||
.apply(&self.conv2)?
|
||||
.gelu()?
|
||||
.apply(&self.conv3)?
|
||||
.flatten_from(2)?
|
||||
.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConvLayer {
|
||||
blocks: Vec<MBConv>,
|
||||
downsample: Option<PatchMerging>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ConvLayer {
|
||||
fn new(
|
||||
dim: usize,
|
||||
out: usize,
|
||||
input_resolution: (usize, usize),
|
||||
depth: usize,
|
||||
downsample: bool,
|
||||
conv_expand_ratio: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let vb_b = vb.pp("blocks");
|
||||
let mut blocks = Vec::with_capacity(depth);
|
||||
for index in 0..depth {
|
||||
let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?;
|
||||
blocks.push(block)
|
||||
}
|
||||
let downsample = if downsample {
|
||||
let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
|
||||
Some(downsample)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
|
||||
Ok(Self {
|
||||
blocks,
|
||||
downsample,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ConvLayer {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for block in self.blocks.iter() {
|
||||
xs = block.forward(&xs)?
|
||||
}
|
||||
match &self.downsample {
|
||||
None => Ok(xs),
|
||||
Some(downsample) => downsample.forward(&xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Mlp {
|
||||
norm: candle_nn::LayerNorm,
|
||||
fc1: crate::Linear,
|
||||
fc2: crate::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
|
||||
let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
|
||||
let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mlp");
|
||||
Ok(Self {
|
||||
norm,
|
||||
fc1,
|
||||
fc2,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
xs.apply(&self.norm)?
|
||||
.apply(&self.fc1)?
|
||||
.gelu()?
|
||||
.apply(&self.fc2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
struct Attention {
    norm: candle_nn::LayerNorm,
    qkv: crate::Linear,
    proj: crate::Linear,
    ab: Tensor,
    key_dim: usize,
    num_heads: usize,
    d: usize,
    dh: usize,
    scale: f64,
    span: tracing::Span,
    span_matmul: tracing::Span,
    span_softmax: tracing::Span,
}

impl Attention {
    fn new(
        dim: usize,
        key_dim: usize,
        num_heads: usize,
        attn_ratio: usize,
        resolution: (usize, usize),
        vb: VarBuilder,
    ) -> Result<Self> {
        let d = attn_ratio * key_dim;
        let dh = d * num_heads;
        let nh_kd = key_dim * num_heads;
        let h = dh + nh_kd * 2;
        let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
        let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
        let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;

        let points = (0..resolution.0)
            .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
            .collect::<Vec<_>>();
        let mut idxs = Vec::with_capacity(points.len() * points.len());
        let mut attention_offsets = std::collections::HashMap::new();
        for &(x1, y1) in points.iter() {
            for &(x2, y2) in points.iter() {
                let offset = ((x2 - x1).abs(), (y2 - y1).abs());
                let l = attention_offsets.len();
                let idx = attention_offsets.entry(offset).or_insert(l);
                idxs.push(*idx as u32)
            }
        }
        let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
        let idxs = Tensor::new(idxs, attention_biases.device())?;
        let ab = attention_biases
            .index_select(&idxs, 1)?
            .reshape(((), points.len(), points.len()))?;
        let span = tracing::span!(tracing::Level::TRACE, "attention");
        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
        Ok(Self {
            norm,
            qkv,
            proj,
            ab,
            key_dim,
            num_heads,
            d,
            dh,
            scale: 1f64 / (key_dim as f64).sqrt(),
            span,
            span_matmul,
            span_softmax,
        })
    }
}

impl Module for Attention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b, n, _) = xs.dims3()?;
        let xs = xs.apply(&self.norm)?;
        let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
        let q = qkv
            .narrow(D::Minus1, 0, self.key_dim)?
            .permute((0, 2, 1, 3))?
            .contiguous()?;
        let k = qkv
            .narrow(D::Minus1, self.key_dim, self.key_dim)?
            .permute((0, 2, 1, 3))?
            .contiguous()?;
        let v = qkv
            .narrow(D::Minus1, 2 * self.key_dim, self.d)?
            .permute((0, 2, 1, 3))?
            .contiguous()?;
        let attn = {
            let _enter = self.span_matmul.enter();
            (q.matmul(&k.t()?)? * self.scale)?
        };
        let attn = attn.broadcast_add(&self.ab)?;
        let attn = {
            let _enter = self.span_softmax.enter();
            candle_nn::ops::softmax_last_dim(&attn)?
        };
        let attn = {
            let _enter = self.span_matmul.enter();
            attn.matmul(&v)?
        };
        attn.transpose(1, 2)?
            .reshape((b, n, self.dh))?
            .apply(&self.proj)
    }
}
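// A TinyViT transformer block: windowed self-attention, a depth-wise local
// convolution, and an MLP, each wrapped in a residual connection.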
#[derive(Debug)]
struct TinyViTBlock {
    attn: Attention,
    local_conv: Conv2dBN,
    mlp: Mlp,
    window_size: usize,
    input_resolution: (usize, usize),
    span: tracing::Span,
}

impl TinyViTBlock {
    fn new(
        dim: usize,
        input_resolution: (usize, usize),
        num_heads: usize,
        window_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let head_dim = dim / num_heads;
        let attn = Attention::new(
            dim,
            head_dim,
            num_heads,
            1,
            (window_size, window_size),
            vb.pp("attn"),
        )?;
        let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
        let cfg = candle_nn::Conv2dConfig {
            padding: LOCAL_CONV_SIZE / 2,
            groups: dim,
            ..Default::default()
        };
        let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
        let span = tracing::span!(tracing::Level::TRACE, "attention");
        Ok(Self {
            attn,
            local_conv,
            mlp,
            window_size,
            input_resolution,
            span,
        })
    }
}

impl Module for TinyViTBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (h, w) = self.input_resolution;
        let (b, l, c) = xs.dims3()?;
        let res_x = xs;
        let xs = if h == self.window_size && w == self.window_size {
            self.attn.forward(xs)?
        } else {
            let xs = xs.reshape((b, h, w, c))?;
            let pad_b = (self.window_size - h % self.window_size) % self.window_size;
            let pad_r = (self.window_size - w % self.window_size) % self.window_size;

            let xs = if pad_b > 0 {
                xs.pad_with_zeros(1, 0, pad_b)?
            } else {
                xs
            };
            let xs = if pad_r > 0 {
                xs.pad_with_zeros(2, 0, pad_r)?
            } else {
                xs
            };
            let (p_h, p_w) = (h + pad_b, w + pad_r);
            let n_h = p_h / self.window_size;
            let n_w = p_w / self.window_size;
            let xs = xs
                .reshape((b, n_h, self.window_size, n_w, self.window_size, c))?
                .transpose(2, 3)?
                .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?;
            let xs = self.attn.forward(&xs)?;
            let xs = xs
                .reshape((b, n_h, n_w, self.window_size, self.window_size, c))?
                .transpose(2, 3)?
                .reshape((b, p_h, p_w, c))?;
            let xs = if pad_r > 0 {
                xs.i((.., .., ..w))?.contiguous()?
            } else {
                xs
            };
            let xs = if pad_b > 0 {
                xs.i((.., ..h, ..))?.contiguous()?
            } else {
                xs
            };
            xs.reshape((b, l, c))?
        };
        let xs = (xs + res_x)?;
        let xs = xs
            .transpose(1, 2)?
            .reshape((b, c, h, w))?
            .apply(&self.local_conv)?
            .reshape((b, c, l))?
            .transpose(1, 2)?;
        &xs + self.mlp.forward(&xs)?
    }
}
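// A transformer stage: a sequence of TinyViT blocks with an optional
// patch-merging downsample at the end.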
#[derive(Debug)]
struct BasicLayer {
    blocks: Vec<TinyViTBlock>,
    downsample: Option<PatchMerging>,
    span: tracing::Span,
}

impl BasicLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        dim: usize,
        input_resolution: (usize, usize),
        depth: usize,
        num_heads: usize,
        window_size: usize,
        downsample: bool,
        out: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb_b = vb.pp("blocks");
        let mut blocks = Vec::with_capacity(depth);
        for index in 0..depth {
            let block = TinyViTBlock::new(
                dim,
                input_resolution,
                num_heads,
                window_size,
                vb_b.pp(index),
            )?;
            blocks.push(block)
        }
        let downsample = if downsample {
            let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
            Some(downsample)
        } else {
            None
        };
        let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
        Ok(Self {
            blocks,
            downsample,
            span,
        })
    }
}

impl Module for BasicLayer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut xs = xs.clone();
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
        match &self.downsample {
            None => Ok(xs),
            Some(downsample) => downsample.forward(&xs),
        }
    }
}
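// The full TinyViT image encoder: patch embedding, one convolutional stage,
// a series of transformer stages, and a convolutional neck that projects the
// final features to a 256-channel embedding.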
#[derive(Debug)]
pub struct TinyViT {
    patch_embed: PatchEmbed,
    layer0: ConvLayer,
    layers: Vec<BasicLayer>,
    // norm_head: candle_nn::LayerNorm,
    // head: candle_nn::Linear,
    neck_conv1: candle_nn::Conv2d,
    neck_ln1: crate::LayerNorm2d,
    neck_conv2: candle_nn::Conv2d,
    neck_ln2: crate::LayerNorm2d,
    span: tracing::Span,
    span_neck: tracing::Span,
}

impl TinyViT {
    pub fn new(
        embed_dims: &[usize],
        depths: &[usize],
        num_heads: &[usize],
        window_sizes: &[usize],
        _num_classes: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
        let patches_resolution = IMG_SIZE / 4;

        let vb_l = vb.pp("layers");
        let layer0 = ConvLayer::new(
            /* dim */ embed_dims[0],
            /* out */ embed_dims[1],
            /* input_resolution */ (patches_resolution, patches_resolution),
            /* depth */ depths[0],
            /* downsample */ true,
            /* conv_expand_ratio */ MBCONV_EXPAND_RATIO,
            vb_l.pp(0),
        )?;

        let num_layers = embed_dims.len();
        let mut layers = Vec::with_capacity(num_layers - 1);
        for i_layer in 1..num_layers {
            let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
            let layer = BasicLayer::new(
                /* dim */ embed_dims[i_layer],
                /* input_resolution */ (patches_resolution, patches_resolution),
                /* depth */ depths[i_layer],
                /* num_heads */ num_heads[i_layer],
                /* window_size */ window_sizes[i_layer],
                /* downsample */ i_layer < num_layers - 1,
                /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)],
                vb_l.pp(i_layer),
            )?;
            layers.push(layer)
        }

        let last_embed_dim = embed_dims[embed_dims.len() - 1];
        // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
        // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
        let neck_conv1 =
            candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
        let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
        let cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
        let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;

        let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
        let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
        Ok(Self {
            patch_embed,
            layer0,
            layers,
            neck_conv1,
            neck_ln1,
            neck_conv2,
            neck_ln2,
            span,
            span_neck,
        })
    }
}

impl Module for TinyViT {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let xs = self.patch_embed.forward(xs)?;
        let mut xs = self.layer0.forward(&xs)?;
        for layer in self.layers.iter() {
            xs = layer.forward(&xs)?
        }
        let (b, _, c) = xs.dims3()?;
        let _enter = self.span_neck.enter();
        xs.reshape((b, 64, 64, c))?
            .permute((0, 3, 1, 2))?
            .apply(&self.neck_conv1)?
            .apply(&self.neck_ln1)?
            .apply(&self.neck_conv2)?
            .apply(&self.neck_ln2)
    }
}
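// Builds the 5M-parameter TinyViT configuration.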
pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
    TinyViT::new(
        /* embed_dims */ &[64, 128, 160, 320],
        /* depths */ &[2, 2, 6, 2],
        /* num_heads */ &[2, 4, 5, 10],
        /* window_sizes */ &[7, 7, 14, 7],
        /* num_classes */ 1000,
        vb,
    )
}
@ -1,221 +0,0 @@
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
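// Multi-head attention with separate q/k/v projections; the internal dimension
// can be reduced by `downsample_rate` relative to the embedding dimension.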
#[derive(Debug)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
}

impl Attention {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        downsample_rate: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let internal_dim = embedding_dim / downsample_rate;
        let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
        let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
        let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
        let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
        })
    }

    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n, c) = x.dims3()?;
        x.reshape((b, n, self.num_heads, c / self.num_heads))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
        x.transpose(1, 2)?
            .reshape((b, n_tokens, n_heads * c_per_head))
    }

    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let q = self.q_proj.forward(&q.contiguous()?)?;
        let k = self.k_proj.forward(&k.contiguous()?)?;
        let v = self.v_proj.forward(&v.contiguous()?)?;

        let q = self.separate_heads(&q)?;
        let k = self.separate_heads(&k)?;
        let v = self.separate_heads(&v)?;

        let (_, _, _, c_per_head) = q.dims4()?;
        let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;

        let out = attn.matmul(&v)?;
        self.recombine_heads(&out)?.apply(&self.out_proj)
    }
}
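// One block of the two-way transformer: self-attention over the query tokens,
// cross-attention from tokens to the image embedding, an MLP, and
// cross-attention back from the image embedding to the tokens, each followed
// by a layer norm.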
#[derive(Debug)]
struct TwoWayAttentionBlock {
    self_attn: Attention,
    norm1: LayerNorm,
    cross_attn_token_to_image: Attention,
    norm2: LayerNorm,
    mlp: crate::MlpBlock,
    norm3: LayerNorm,
    norm4: LayerNorm,
    cross_attn_image_to_token: Attention,
    skip_first_layer_pe: bool,
}

impl TwoWayAttentionBlock {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        skip_first_layer_pe: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
        let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
        let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
        let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
        let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
        let cross_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_token_to_image"),
        )?;
        let cross_attn_image_to_token = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_image_to_token"),
        )?;
        let mlp = crate::MlpBlock::new(
            embedding_dim,
            mlp_dim,
            candle_nn::Activation::Relu,
            vb.pp("mlp"),
        )?;
        Ok(Self {
            self_attn,
            norm1,
            cross_attn_image_to_token,
            norm2,
            mlp,
            norm3,
            norm4,
            cross_attn_token_to_image,
            skip_first_layer_pe,
        })
    }

    fn forward(
        &self,
        queries: &Tensor,
        keys: &Tensor,
        query_pe: &Tensor,
        key_pe: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        // Self attention block
        let queries = if self.skip_first_layer_pe {
            self.self_attn.forward(queries, queries, queries)?
        } else {
            let q = (queries + query_pe)?;
            let attn_out = self.self_attn.forward(&q, &q, queries)?;
            (queries + attn_out)?
        };
        let queries = self.norm1.forward(&queries)?;

        // Cross attention block, tokens attending to image embedding
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
        let queries = (&queries + attn_out)?;
        let queries = self.norm2.forward(&queries)?;

        // MLP block
        let mlp_out = self.mlp.forward(&queries);
        let queries = (queries + mlp_out)?;
        let queries = self.norm3.forward(&queries)?;

        // Cross attention block, image embedding attending to tokens
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
        let keys = (keys + attn_out)?;
        let keys = self.norm4.forward(&keys)?;

        Ok((queries, keys))
    }
}
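// The two-way transformer used by the SAM mask decoder: a stack of
// TwoWayAttentionBlocks followed by a final token-to-image attention and a
// layer norm over the query tokens.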
#[derive(Debug)]
pub struct TwoWayTransformer {
    layers: Vec<TwoWayAttentionBlock>,
    final_attn_token_to_image: Attention,
    norm_final_attn: LayerNorm,
}

impl TwoWayTransformer {
    pub fn new(
        depth: usize,
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let mut layers = Vec::with_capacity(depth);
        for i in 0..depth {
            let layer =
                TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
            layers.push(layer)
        }
        let final_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("final_attn_token_to_image"),
        )?;
        let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
        Ok(Self {
            layers,
            final_attn_token_to_image,
            norm_final_attn,
        })
    }

    pub fn forward(
        &self,
        image_embedding: &Tensor,
        image_pe: &Tensor,
        point_embedding: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
        let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;

        let mut queries = point_embedding.clone();
        let mut keys = image_embedding;

        for layer in self.layers.iter() {
            (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
        }

        let q = (&queries + point_embedding)?;
        let k = (&keys + image_pe)?;
        let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
        let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;

        Ok((queries, keys))
    }
}