Move more models to candle-transformers (#796)

* Move dinov2.

* Move efficientnet.

* Move the quantized llama model.

* Move segment-anything.
This commit is contained in:
Laurent Mazare
2023-09-10 10:20:18 +01:00
committed by GitHub
parent d3f05eae8c
commit 35f72514f5
21 changed files with 773 additions and 759 deletions

View File

@ -9,285 +9,10 @@ extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2;
const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias {
candle_nn::linear(in_dim, out_dim, vb)
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug)]
struct Attention {
qkv: Linear,
proj: Linear,
num_heads: usize,
scale: f64,
}
impl Attention {
fn new(
vb: VarBuilder,
dim: usize,
num_heads: usize,
qkv_bias: bool,
proj_bias: bool,
) -> Result<Self> {
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
let scale = 1. / ((dim / num_heads) as f64).sqrt();
Ok(Self {
qkv,
proj,
num_heads,
scale,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, n, c) = xs.dims3()?;
let qkv = self
.qkv
.forward(xs)?
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
.transpose(1, 2)? // 02134
.transpose(0, 1)? // 20134
.transpose(2, 3)?; // 20314
let q = (qkv.i(0)? * self.scale)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
self.proj.forward(&attn)
}
}
#[derive(Debug)]
struct LayerScale {
gamma: Tensor,
}
impl LayerScale {
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
let gamma = vb.get(dim, "gamma")?;
Ok(Self { gamma })
}
}
impl Module for LayerScale {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&self.gamma)
}
}
#[derive(Debug)]
struct Mlp {
fc1: Linear,
fc2: Linear,
}
impl Mlp {
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
let out_features = in_features;
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?.gelu()?;
self.fc2.forward(&xs)
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
ls1: LayerScale,
norm2: LayerNorm,
mlp: Mlp,
ls2: LayerScale,
}
impl Block {
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
Ok(Self {
norm1,
attn,
ls1,
norm2,
mlp,
ls2,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self
.ls1
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.ls2
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
xs + residual
}
}
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
patch_size: (usize, usize),
num_patches: usize,
}
impl PatchEmbed {
fn new(
vb: VarBuilder,
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
) -> Result<Self> {
let config = candle_nn::Conv2dConfig {
stride: patch_size,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
let num_patches = (img_size / patch_size) * (img_size / patch_size);
Ok(Self {
proj,
patch_size: (patch_size, patch_size),
num_patches,
})
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _c, h, w) = xs.dims4()?;
let (patch_h, patch_w) = self.patch_size;
if (h % patch_h) != 0 {
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
}
if (w % patch_w) != 0 {
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
}
let xs = self.proj.forward(xs)?;
let (b, c, h, w) = xs.dims4()?;
// flatten embeddings.
xs.reshape((b, c, h * w))?.transpose(1, 2)
}
}
#[derive(Debug)]
pub struct DinoVisionTransformer {
patch_embed: PatchEmbed,
cls_token: Tensor,
pos_embed: Tensor,
blocks: Vec<Block>,
norm: LayerNorm,
head: Linear,
}
impl DinoVisionTransformer {
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
let patch_embed =
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
let num_tokens = 1;
let pos_embed = vb.get(
(1, patch_embed.num_patches + num_tokens, embed_dim),
"pos_embed",
)?;
let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
cls_token,
pos_embed,
blocks,
norm,
head,
})
}
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
let npatch = xs.dim(1)? - 1;
let n = self.pos_embed.dim(1)? - 1;
let sqrt_n = (n as f64).sqrt();
if npatch == n && w == h {
return Ok(xs.clone());
}
let class_pos_embed = self.pos_embed.i((.., ..1))?;
let patch_pos_embed = self.pos_embed.i((.., 1..))?;
let dim = xs.dim(D::Minus1)?;
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
let patch_pos_embed = patch_pos_embed
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
.transpose(2, 3)?
.transpose(1, 2)?;
// This uses bicubic interpolation in the original implementation.
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
let el_count = patch_pos_embed.shape().elem_count();
let patch_pos_embed =
patch_pos_embed
.transpose(1, 2)?
.transpose(2, 3)?
.reshape((1, el_count / dim, dim))?;
Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
}
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _nc, w, h) = xs.dims4()?;
let xs = self.patch_embed.forward(xs)?;
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
}
impl Module for DinoVisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
for blk in self.blocks.iter() {
xs = blk.forward(&xs)?
}
let xs = self.norm.forward(&xs)?;
let xs_norm_clstoken = xs.i((.., 0))?;
let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
self.head.forward(&xs)
}
}
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
DinoVisionTransformer::new(vb, 12, 384, 6)
}
#[derive(Parser)]
struct Args {
#[arg(long)]
@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = vit_small(vb)?;
let model = dinov2::vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?

View File

@ -8,340 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};
// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
expand_ratio: f64,
kernel: usize,
stride: usize,
input_channels: usize,
out_channels: usize,
num_layers: usize,
}
fn make_divisible(v: f64, divisor: usize) -> usize {
let min_value = divisor;
let new_v = usize::max(
min_value,
(v + divisor as f64 * 0.5) as usize / divisor * divisor,
);
if (new_v as f64) < 0.9 * v {
new_v + divisor
} else {
new_v
}
}
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
let bneck_conf = |e, k, s, i, o, n| {
let input_channels = make_divisible(i as f64 * width_mult, 8);
let out_channels = make_divisible(o as f64 * width_mult, 8);
let num_layers = (n as f64 * depth_mult).ceil() as usize;
MBConvConfig {
expand_ratio: e,
kernel: k,
stride: s,
input_channels,
out_channels,
num_layers,
}
};
vec![
bneck_conf(1., 3, 1, 32, 16, 1),
bneck_conf(6., 3, 2, 16, 24, 2),
bneck_conf(6., 5, 2, 24, 40, 2),
bneck_conf(6., 3, 2, 40, 80, 3),
bneck_conf(6., 5, 1, 80, 112, 3),
bneck_conf(6., 5, 2, 112, 192, 4),
bneck_conf(6., 3, 1, 192, 320, 1),
]
}
impl MBConvConfig {
fn b0() -> Vec<Self> {
bneck_confs(1.0, 1.0)
}
fn b1() -> Vec<Self> {
bneck_confs(1.0, 1.1)
}
fn b2() -> Vec<Self> {
bneck_confs(1.1, 1.2)
}
fn b3() -> Vec<Self> {
bneck_confs(1.2, 1.4)
}
fn b4() -> Vec<Self> {
bneck_confs(1.4, 1.8)
}
fn b5() -> Vec<Self> {
bneck_confs(1.6, 2.2)
}
fn b6() -> Vec<Self> {
bneck_confs(1.8, 2.6)
}
fn b7() -> Vec<Self> {
bneck_confs(2.0, 3.1)
}
}
/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
conv2d: nn::Conv2d,
s: usize,
k: usize,
}
impl Conv2DSame {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
bias: bool,
) -> Result<Self> {
let conv_config = nn::Conv2dConfig {
stride,
groups,
..Default::default()
};
let conv2d = if bias {
nn::conv2d(i, o, k, conv_config, vb)?
} else {
nn::conv2d_no_bias(i, o, k, conv_config, vb)?
};
Ok(Self {
conv2d,
s: stride,
k,
})
}
}
impl Module for Conv2DSame {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let s = self.s;
let k = self.k;
let (_, _, ih, iw) = xs.dims4()?;
let oh = (ih + s - 1) / s;
let ow = (iw + s - 1) / s;
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
if pad_h > 0 || pad_w > 0 {
let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
self.conv2d.forward(&xs)
} else {
self.conv2d.forward(xs)
}
}
}
#[derive(Debug)]
struct ConvNormActivation {
conv2d: Conv2DSame,
bn2d: nn::BatchNorm,
activation: bool,
}
impl ConvNormActivation {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
) -> Result<Self> {
let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
Ok(Self {
conv2d,
bn2d,
activation: true,
})
}
fn no_activation(self) -> Self {
Self {
activation: false,
..self
}
}
}
impl Module for ConvNormActivation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv2d.forward(xs)?;
let xs = self.bn2d.forward(&xs)?;
if self.activation {
swish(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
struct SqueezeExcitation {
fc1: Conv2DSame,
fc2: Conv2DSame,
}
impl SqueezeExcitation {
fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for SqueezeExcitation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
// equivalent to adaptive_avg_pool2d([1, 1])
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = self.fc1.forward(&xs)?;
let xs = swish(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = nn::ops::sigmoid(&xs)?;
residual.broadcast_mul(&xs)
}
}
#[derive(Debug)]
struct MBConv {
expand_cna: Option<ConvNormActivation>,
depthwise_cna: ConvNormActivation,
squeeze_excitation: SqueezeExcitation,
project_cna: ConvNormActivation,
config: MBConvConfig,
}
impl MBConv {
fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
let vb = vb.pp("block");
let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
let expand_cna = if exp != c.input_channels {
Some(ConvNormActivation::new(
vb.pp("0"),
c.input_channels,
exp,
1,
1,
1,
)?)
} else {
None
};
let start_index = if expand_cna.is_some() { 1 } else { 0 };
let depthwise_cna =
ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
let squeeze_channels = usize::max(1, c.input_channels / 4);
let squeeze_excitation =
SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
let project_cna =
ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
.no_activation();
Ok(Self {
expand_cna,
depthwise_cna,
squeeze_excitation,
project_cna,
config: c,
})
}
}
impl Module for MBConv {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let use_res_connect =
self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
let ys = match &self.expand_cna {
Some(expand_cna) => expand_cna.forward(xs)?,
None => xs.clone(),
};
let ys = self.depthwise_cna.forward(&ys)?;
let ys = self.squeeze_excitation.forward(&ys)?;
let ys = self.project_cna.forward(&ys)?;
if use_res_connect {
ys + xs
} else {
Ok(ys)
}
}
}
fn swish(s: &Tensor) -> Result<Tensor> {
s * nn::ops::sigmoid(s)?
}
#[derive(Debug)]
struct EfficientNet {
init_cna: ConvNormActivation,
blocks: Vec<MBConv>,
final_cna: ConvNormActivation,
classifier: nn::Linear,
}
impl EfficientNet {
fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();
let mut blocks = vec![];
for (index, cnf) in configs.into_iter().enumerate() {
let f_p = f_p.pp(index + 1);
for r_index in 0..cnf.num_layers {
let cnf = if r_index == 0 {
cnf
} else {
MBConvConfig {
input_channels: cnf.out_channels,
stride: 1,
..cnf
}
};
blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
}
}
let final_cna =
ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
Ok(Self {
init_cna,
blocks,
final_cna,
classifier,
})
}
}
impl Module for EfficientNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.init_cna.forward(xs)?;
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
let xs = self.final_cna.forward(&xs)?;
// Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
self.classifier.forward(&xs)
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
B0,

View File

@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
mod model;
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";

View File

@ -1,371 +0,0 @@
use std::collections::HashMap;
use candle::quantized::QTensor;
use candle::quantized::{ggml_file, gguf_file};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 4096;
struct RmsNorm {
inner: candle_nn::LayerNorm,
span: tracing::Span,
}
impl RmsNorm {
fn new(scale: QTensor, eps: f32) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&Device::Cpu)?;
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
Ok(Self { inner, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
// QMatMul wrapper adding some tracing.
struct QMatMul {
inner: candle::quantized::QMatMul,
span: tracing::Span,
}
impl QMatMul {
fn from_qtensor(qtensor: QTensor) -> Self {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Self { inner, span }
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
struct LayerWeights {
attention_wq: QMatMul,
attention_wk: QMatMul,
attention_wv: QMatMul,
attention_wo: QMatMul,
attention_norm: RmsNorm,
feed_forward_w1: QMatMul,
feed_forward_w2: QMatMul,
feed_forward_w3: QMatMul,
ffn_norm: RmsNorm,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
cos: Tensor,
sin: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
span_attn: tracing::Span,
span_rot: tracing::Span,
span_mlp: tracing::Span,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
impl LayerWeights {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
let cos = self
.cos
.narrow(0, index_pos, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?;
let sin = self
.sin
.narrow(0, index_pos, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
// This mimics the llama.cpp behavior.
// https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
// The resulting y0 and y1 are also interleaved with:
// y0 = x0*cos - x1*sin
// y1 = x0*sin + x1*cos
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
let rope = rope.flatten_from(D::Minus2)?;
Ok(rope)
}
fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?;
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let k = self.apply_rotary_emb(&k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k, v)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
(k, v)
}
}
};
self.kv_cache = Some((k.clone(), v.clone()));
// Support for MQA, useful for 70B models.
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = mask.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.attention_wo.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
Ok(x)
}
}
}
pub struct ModelWeights {
tok_embeddings: Embedding,
layers: Vec<LayerWeights>,
norm: RmsNorm,
output: QMatMul,
masks: HashMap<usize, Tensor>,
span: tracing::Span,
span_output: tracing::Span,
}
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))
}
impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let cpu = &Device::Cpu;
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
for layer_idx in 0..ct.hparams.n_layer {
let prefix = format!("layers.{layer_idx}");
let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
cos: cos.clone(),
sin: sin.clone(),
kv_cache: None,
span_attn,
span_rot,
span_mlp,
})
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
layers,
norm,
output: QMatMul::from_qtensor(output),
masks: HashMap::new(),
span,
span_output,
})
}
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
// Parameter extraction from metadata.
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
let output = ct.tensor(reader, "output.weight")?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
cos: cos.clone(),
sin: sin.clone(),
kv_cache: None,
span_attn,
span_rot,
span_mlp,
})
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
layers,
norm,
output: QMatMul::from_qtensor(output),
masks: HashMap::new(),
span,
span_output,
})
}
fn mask(&mut self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
}
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len)?;
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
for layer in self.layers.iter_mut() {
let x = layer_in;
let residual = &x;
let x = layer.attention_norm.forward(&x)?;
let attn = layer.forward_attn(&x, &mask, index_pos)?;
let x = (attn + residual)?;
// MLP
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let w1 = layer.feed_forward_w1.forward(&x)?;
let w3 = layer.feed_forward_w3.forward(&x)?;
let mlp = layer
.feed_forward_w2
.forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
layer_in = (mlp + residual)?;
}
let x = self.norm.forward(&layer_in)?;
let x = x.i((.., seq_len - 1, ..))?;
let _enter = self.span_output.enter();
self.output.forward(&x)
}
}

View File

@ -7,108 +7,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
pub mod model_image_encoder;
pub mod model_mask_decoder;
pub mod model_prompt_encoder;
pub mod model_sam;
pub mod model_tiny_vit;
pub mod model_transformer;
use candle::{DType, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle::DType;
use candle_nn::VarBuilder;
use candle_transformers::models::segment_anything::sam;
use clap::Parser;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
let inner = if bias {
candle_nn::linear(in_dim, out_dim, vb)?
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
};
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
}
#[derive(Debug)]
pub struct LayerNorm2d {
weight: Tensor,
bias: Tensor,
num_channels: usize,
eps: f64,
}
impl LayerNorm2d {
pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(num_channels, "weight")?;
let bias = vb.get(num_channels, "bias")?;
Ok(Self {
weight,
bias,
num_channels,
eps,
})
}
}
impl Module for LayerNorm2d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let u = xs.mean_keepdim(1)?;
let xs = xs.broadcast_sub(&u)?;
let s = xs.sqr()?.mean_keepdim(1)?;
let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
.broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
}
}
#[derive(Debug)]
pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
activation: candle_nn::Activation,
span: tracing::Span,
}
impl MlpBlock {
pub fn new(
embedding_dim: usize,
mlp_dim: usize,
activation: candle_nn::Activation,
vb: VarBuilder,
) -> Result<Self> {
let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
Ok(Self {
lin1,
lin2,
activation,
span,
})
}
}
impl Module for MlpBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> {
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> {
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = if args.use_tiny {
model_sam::Sam::new_tiny(vb)? // tiny vit_t
sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks {

View File

@ -1,483 +0,0 @@
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
span: tracing::Span,
}
impl PatchEmbed {
fn new(
in_chans: usize,
embed_dim: usize,
k_size: usize,
stride: usize,
padding: usize,
vb: VarBuilder,
) -> Result<Self> {
let cfg = candle_nn::Conv2dConfig {
stride,
padding,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
Ok(Self { proj, span })
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.proj)?.permute((0, 2, 3, 1))
}
}
// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096
// (attn.reshape((b, q_h, q_w, k_h, k_w))?
// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
// .reshape((b, q_h * q_w, k_h * k_w))
// Ideally we would perform this operation in place but this is not supported in candle at the
// moment. We should also investigate using f16 rather than f32.
struct Add3(usize, usize, usize, usize, usize);
impl candle::CustomOp3 for Add3 {
fn name(&self) -> &'static str {
"add3"
}
fn cpu_fwd(
&self,
s1: &candle::CpuStorage,
l1: &candle::Layout,
s2: &candle::CpuStorage,
l2: &candle::Layout,
s3: &candle::CpuStorage,
l3: &candle::Layout,
) -> Result<(candle::CpuStorage, candle::Shape)> {
use rayon::prelude::*;
let Add3(b, q_h, q_w, k_h, k_w) = *self;
let s1 = s1.as_slice::<f32>()?;
let s1 = match l1.contiguous_offsets() {
None => candle::bail!("input1 has to be contiguous"),
Some((o1, o2)) => &s1[o1..o2],
};
let s2 = s2.as_slice::<f32>()?;
let s2 = match l2.contiguous_offsets() {
None => candle::bail!("input2 has to be contiguous"),
Some((o1, o2)) => &s2[o1..o2],
};
let s3 = s3.as_slice::<f32>()?;
let s3 = match l3.contiguous_offsets() {
None => candle::bail!("input3 has to be contiguous"),
Some((o1, o2)) => &s3[o1..o2],
};
let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
dst.par_chunks_exact_mut(k_h * k_w)
.enumerate()
.for_each(|(b_idx, dst)| {
let s1_idx = b_idx * k_h * k_w;
let s2_idx = b_idx * k_h;
let s3_idx = b_idx * k_w;
for h_idx in 0..k_h {
let s1_idx = s1_idx + h_idx * k_w;
let s2_idx = s2_idx + h_idx;
let dst_idx = h_idx * k_w;
for w_idx in 0..k_w {
let s1_idx = s1_idx + w_idx;
let s3_idx = s3_idx + w_idx;
let dst_idx = dst_idx + w_idx;
dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
}
}
});
let dst = candle::WithDType::to_cpu_storage_owned(dst);
Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
}
}
#[derive(Debug)]
struct Attention {
qkv: crate::Linear,
proj: crate::Linear,
num_heads: usize,
scale: f64,
rel_pos_hw: Option<(Tensor, Tensor)>,
span: tracing::Span,
span_matmul: tracing::Span,
span_rel_pos: tracing::Span,
span_softmax: tracing::Span,
}
impl Attention {
fn new(
dim: usize,
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "attention");
let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
let head_dim = dim / num_heads;
let scale = 1. / (head_dim as f64).sqrt();
let rel_pos_hw = if use_rel_pos {
let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
Some((h, w))
} else {
None
};
Ok(Self {
qkv,
proj,
num_heads,
scale,
rel_pos_hw,
span,
span_matmul,
span_rel_pos,
span_softmax,
})
}
fn add_decomposed_rel_pos(
&self,
attn: Tensor,
q: &Tensor,
(q_h, q_w): (usize, usize),
(k_h, k_w): (usize, usize),
) -> Result<Tensor> {
match &self.rel_pos_hw {
Some((rel_pos_h, rel_pos_w)) => {
let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
let (b, _, dim) = q.dims3()?;
let r_q = q.reshape((b, q_h, q_w, dim))?;
// rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
// rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
let rel_w = r_q
.transpose(1, 2)? // -> bwhc
.contiguous()?
.matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
.transpose(1, 2)?
.contiguous()?;
if attn.device().is_cpu() {
let op = Add3(b, q_h, q_w, k_h, k_w);
attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
} else {
(attn.reshape((b, q_h, q_w, k_h, k_w))?
+ rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
.reshape((b, q_h * q_w, k_h * k_w))
}
}
None => Ok(attn),
}
}
}
fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
let dev = rel_pos.device();
let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
todo!("interpolation")
} else {
rel_pos
};
let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
.reshape((q_size, 1))?
.to_dtype(DType::F32)?;
let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
.reshape((1, k_size))?
.to_dtype(DType::F32)?;
let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
let relative_coords = (q_coords.broadcast_sub(&k_coords)?
+ (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
let (d1, d2) = relative_coords.dims2()?;
let relative_coords = relative_coords.to_dtype(DType::U32)?;
rel_pos_resized
.index_select(&relative_coords.reshape(d1 * d2)?, 0)?
.reshape((d1, d2, ()))
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b, h, w, c) = xs.dims4()?;
let qkv = self
.qkv
.forward(&xs.flatten_to(1)?)?
.reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
.permute((2, 0, 3, 1, 4))?
.reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
let q = qkv.i(0)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = {
let _enter = self.span_matmul.enter();
(&q * self.scale)?.matmul(&k.t()?)?
};
let attn = {
let _enter = self.span_rel_pos.enter();
self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
};
let attn = {
let _enter = self.span_softmax.enter();
candle_nn::ops::softmax_last_dim(&attn)?
};
let attn = {
let _enter = self.span_matmul.enter();
attn.matmul(&v)?
};
let attn = attn
.reshape((b, self.num_heads, h, w, c / self.num_heads))?
.permute((0, 2, 3, 1, 4))?
.reshape((b, h * w, c))?;
self.proj.forward(&attn)?.reshape((b, h, w, c))
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
norm2: LayerNorm,
mlp: crate::MlpBlock,
window_size: usize,
span: tracing::Span,
}
impl Block {
fn new(
dim: usize,
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
let input_size_attn = if window_size == 0 {
input_size
} else {
(window_size, window_size)
};
let attn = Attention::new(
dim,
num_heads,
qkv_bias,
use_rel_pos,
input_size_attn,
vb.pp("attn"),
)?;
let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
let span = tracing::span!(tracing::Level::TRACE, "ie-block");
Ok(Self {
norm1,
attn,
norm2,
mlp,
window_size,
span,
})
}
}
fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
let (b, h, w, c) = xs.dims4()?;
let pad_h = (window_size - h % window_size) % window_size;
let pad_w = (window_size - w % window_size) % window_size;
let xs = if pad_h > 0 {
xs.pad_with_zeros(1, 0, pad_h)?
} else {
xs
};
let xs = if pad_w > 0 {
xs.pad_with_zeros(2, 0, pad_w)?
} else {
xs
};
let (h_p, w_p) = (h + pad_h, w + pad_w);
let windows = xs
.reshape((
b,
h_p / window_size,
window_size,
w_p / window_size,
window_size,
c,
))?
.transpose(2, 3)?
.contiguous()?
.flatten_to(2)?;
Ok((windows, (h_p, w_p)))
}
fn window_unpartition(
windows: Tensor,
window_size: usize,
(h_p, w_p): (usize, usize),
(h, w): (usize, usize),
) -> Result<Tensor> {
let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
let xs = windows
.reshape((
b,
h_p / window_size,
w_p / window_size,
window_size,
window_size,
windows.elem_count() / b / h_p / w_p,
))?
.transpose(2, 3)?
.contiguous()?
.reshape((b, h_p, w_p, ()))?;
let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
Ok(xs)
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let shortcut = xs;
let xs = self.norm1.forward(xs)?;
let hw = (xs.dim(1)?, xs.dim(2)?);
let (xs, pad_hw) = if self.window_size > 0 {
window_partition(xs, self.window_size)?
} else {
(xs, (0, 0))
};
let xs = self.attn.forward(&xs)?;
let xs = if self.window_size > 0 {
window_unpartition(xs, self.window_size, pad_hw, hw)?
} else {
xs
};
let xs = (xs + shortcut)?;
&xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
}
}
#[derive(Debug)]
pub struct ImageEncoderViT {
patch_embed: PatchEmbed,
blocks: Vec<Block>,
neck_conv1: candle_nn::Conv2d,
neck_ln1: crate::LayerNorm2d,
neck_conv2: candle_nn::Conv2d,
neck_ln2: crate::LayerNorm2d,
pos_embed: Option<Tensor>,
span: tracing::Span,
}
impl ImageEncoderViT {
#[allow(clippy::too_many_arguments)]
pub fn new(
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
depth: usize,
num_heads: usize,
out_chans: usize,
qkv_bias: bool,
use_rel_pos: bool,
use_abs_pos: bool,
window_size: usize,
global_attn_indexes: &[usize],
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(
in_chans,
embed_dim,
patch_size,
patch_size,
0,
vb.pp("patch_embed"),
)?;
let mut blocks = Vec::with_capacity(depth);
let vb_b = vb.pp("blocks");
for i in 0..depth {
let window_size = if global_attn_indexes.contains(&i) {
0
} else {
window_size
};
let block = Block::new(
embed_dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
(img_size / patch_size, img_size / patch_size),
vb_b.pp(i),
)?;
blocks.push(block)
}
let neck_conv1 = candle_nn::conv2d_no_bias(
embed_dim,
out_chans,
1,
Default::default(),
vb.pp("neck.0"),
)?;
let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
let pos_embed = if use_abs_pos {
let p = vb.get(
(1, img_size / patch_size, img_size / patch_size, embed_dim),
"pos_embed",
)?;
Some(p)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
Ok(Self {
patch_embed,
blocks,
neck_conv1,
neck_ln1,
neck_conv2,
neck_ln2,
pos_embed,
span,
})
}
}
impl Module for ImageEncoderViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.patch_embed.forward(xs)?;
let mut xs = match &self.pos_embed {
Some(pos_embed) => (xs + pos_embed)?,
None => xs,
};
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
xs.permute((0, 3, 1, 2))?
.apply(&self.neck_conv1)?
.apply(&self.neck_ln1)?
.apply(&self.neck_conv2)?
.apply(&self.neck_ln2)
}
}

View File

@ -1,239 +0,0 @@
use candle::{IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
#[derive(Debug)]
struct MlpMaskDecoder {
layers: Vec<crate::Linear>,
sigmoid_output: bool,
span: tracing::Span,
}
impl MlpMaskDecoder {
fn new(
input_dim: usize,
hidden_dim: usize,
output_dim: usize,
num_layers: usize,
sigmoid_output: bool,
vb: VarBuilder,
) -> Result<Self> {
let mut layers = Vec::with_capacity(num_layers);
let vb = vb.pp("layers");
for i in 0..num_layers {
let in_dim = if i == 0 { input_dim } else { hidden_dim };
let out_dim = if i + 1 == num_layers {
output_dim
} else {
hidden_dim
};
let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
layers.push(layer)
}
let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
Ok(Self {
layers,
sigmoid_output,
span,
})
}
}
impl Module for MlpMaskDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for (i, layer) in self.layers.iter().enumerate() {
xs = layer.forward(&xs)?;
if i + 1 < self.layers.len() {
xs = xs.relu()?
}
}
if self.sigmoid_output {
candle_nn::ops::sigmoid(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
pub struct MaskDecoder {
iou_token: candle_nn::Embedding,
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
output_upscaling_conv1: candle_nn::ConvTranspose2d,
output_upscaling_ln: crate::LayerNorm2d,
output_upscaling_conv2: candle_nn::ConvTranspose2d,
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
transformer: TwoWayTransformer,
span: tracing::Span,
}
impl MaskDecoder {
pub fn new(
transformer_dim: usize,
num_multimask_outputs: usize,
iou_head_depth: usize,
iou_head_hidden_dim: usize,
vb: VarBuilder,
) -> Result<Self> {
let num_mask_tokens = num_multimask_outputs + 1;
let iou_prediction_head = MlpMaskDecoder::new(
transformer_dim,
iou_head_hidden_dim,
num_mask_tokens,
iou_head_depth,
false,
vb.pp("iou_prediction_head"),
)?;
let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
let mask_tokens =
candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
let cfg = candle_nn::ConvTranspose2dConfig {
stride: 2,
..Default::default()
};
let output_upscaling_conv1 = candle_nn::conv_transpose2d(
transformer_dim,
transformer_dim / 4,
2,
cfg,
vb.pp("output_upscaling.0"),
)?;
let output_upscaling_ln =
crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
transformer_dim / 4,
transformer_dim / 8,
2,
cfg,
vb.pp("output_upscaling.3"),
)?;
let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
let vb_o = vb.pp("output_hypernetworks_mlps");
for i in 0..num_mask_tokens {
let mlp = MlpMaskDecoder::new(
transformer_dim,
transformer_dim,
transformer_dim / 8,
3,
false,
vb_o.pp(i),
)?;
output_hypernetworks_mlps.push(mlp)
}
let transformer = TwoWayTransformer::new(
/* depth */ 2,
/* embedding_dim */ transformer_dim,
/* num_heads */ 8,
/* mlp_dim */ 2048,
vb.pp("transformer"),
)?;
let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
Ok(Self {
iou_token,
mask_tokens,
iou_prediction_head,
output_upscaling_conv1,
output_upscaling_ln,
output_upscaling_conv2,
num_mask_tokens,
output_hypernetworks_mlps,
transformer,
span,
})
}
pub fn forward(
&self,
image_embeddings: &Tensor,
image_pe: &Tensor,
sparse_prompt_embeddings: &Tensor,
dense_prompt_embeddings: &Tensor,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();
let (masks, iou_pred) = self.predict_masks(
image_embeddings,
image_pe,
sparse_prompt_embeddings,
dense_prompt_embeddings,
)?;
let masks = if multimask_output {
masks.i((.., 1..))?
} else {
masks.i((.., 0..1))?
};
let iou_pred = if multimask_output {
iou_pred.i((.., 1..))?
} else {
iou_pred.i((.., 0..1))?
};
Ok((masks, iou_pred))
}
fn predict_masks(
&self,
image_embeddings: &Tensor,
image_pe: &Tensor,
sparse_prompt_embeddings: &Tensor,
dense_prompt_embeddings: &Tensor,
) -> Result<(Tensor, Tensor)> {
// Concatenate ouput tokens.
let output_tokens = Tensor::cat(
&[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
0,
)?;
let (d1, d2) = output_tokens.dims2()?;
let output_tokens =
output_tokens
.unsqueeze(0)?
.expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
// Expand per-image data in batch direction to be per mask
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
let src = src.broadcast_add(dense_prompt_embeddings)?;
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
let (b, c, h, w) = src.dims4()?;
// Run the transformer
let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
let iou_token_out = hs.i((.., 0))?;
let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
// Upscale mask embeddings and predict masks using the masks tokens.
let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
let upscaled_embedding = self
.output_upscaling_conv1
.forward(&src)?
.apply(&self.output_upscaling_ln)?
.gelu()?
.apply(&self.output_upscaling_conv2)?
.gelu()?;
let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
hyper_in_list.push(h)
}
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
let (b, c, h, w) = upscaled_embedding.dims4()?;
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
let masks = masks.reshape((b, (), h, w))?;
// Generate mask quality predictions.
let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
Ok((masks, iou_pred))
}
}
// Equivalent to torch.repeat_interleave
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
let img = img.unsqueeze(dim + 1)?;
let mut dims = img.dims().to_vec();
dims[dim + 1] = repeats;
img.broadcast_as(dims)?.flatten(dim, dim + 1)
}

View File

@ -1,239 +0,0 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;
#[derive(Debug)]
struct PostionEmbeddingRandom {
positional_encoding_gaussian_matrix: Tensor,
}
impl PostionEmbeddingRandom {
fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
let positional_encoding_gaussian_matrix =
vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
Ok(Self {
positional_encoding_gaussian_matrix,
})
}
fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
let coords = coords.affine(2., -1.)?;
let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
let coords = (coords * (2. * std::f64::consts::PI))?;
Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
}
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
let device = self.positional_encoding_gaussian_matrix.device();
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let x_embed = (x_embed / w as f64)?
.reshape((1, ()))?
.broadcast_as((h, w))?;
let y_embed = (y_embed / h as f64)?
.reshape(((), 1))?
.broadcast_as((h, w))?;
let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
self.pe_encoding(&coords)?.permute((2, 0, 1))
}
fn forward_with_coords(
&self,
coords_input: &Tensor,
image_size: (usize, usize),
) -> Result<Tensor> {
let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
let c = coords_input.dim(D::Minus1)?;
let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
self.pe_encoding(&coords)
}
}
#[derive(Debug)]
pub struct PromptEncoder {
pe_layer: PostionEmbeddingRandom,
point_embeddings: Vec<candle_nn::Embedding>,
not_a_point_embed: candle_nn::Embedding,
mask_downscaling_conv1: candle_nn::Conv2d,
mask_downscaling_ln1: crate::LayerNorm2d,
mask_downscaling_conv2: candle_nn::Conv2d,
mask_downscaling_ln2: crate::LayerNorm2d,
mask_downscaling_conv3: candle_nn::Conv2d,
no_mask_embed: candle_nn::Embedding,
image_embedding_size: (usize, usize),
input_image_size: (usize, usize),
embed_dim: usize,
span: tracing::Span,
}
impl PromptEncoder {
pub fn new(
embed_dim: usize,
image_embedding_size: (usize, usize),
input_image_size: (usize, usize),
mask_in_chans: usize,
vb: VarBuilder,
) -> Result<Self> {
let num_points_embeddings = 4;
let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
..Default::default()
};
let mask_downscaling_conv1 =
candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
let mask_downscaling_conv2 = candle_nn::conv2d(
mask_in_chans / 4,
mask_in_chans,
2,
cfg,
vb.pp("mask_downscaling.3"),
)?;
let mask_downscaling_conv3 = candle_nn::conv2d(
mask_in_chans,
embed_dim,
1,
Default::default(),
vb.pp("mask_downscaling.6"),
)?;
let mask_downscaling_ln1 =
crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
let mask_downscaling_ln2 =
crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
let vb_e = vb.pp("point_embeddings");
for i in 0..num_points_embeddings {
let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
point_embeddings.push(emb)
}
let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
Ok(Self {
pe_layer,
point_embeddings,
not_a_point_embed,
mask_downscaling_conv1,
mask_downscaling_ln1,
mask_downscaling_conv2,
mask_downscaling_ln2,
mask_downscaling_conv3,
no_mask_embed,
image_embedding_size,
input_image_size,
embed_dim,
span,
})
}
pub fn get_dense_pe(&self) -> Result<Tensor> {
self.pe_layer
.forward(self.image_embedding_size.0, self.image_embedding_size.1)?
.unsqueeze(0)
}
fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
masks
.apply(&self.mask_downscaling_conv1)?
.apply(&self.mask_downscaling_ln1)?
.gelu()?
.apply(&self.mask_downscaling_conv2)?
.apply(&self.mask_downscaling_ln2)?
.gelu()?
.apply(&self.mask_downscaling_conv3)
}
fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
let points = (points + 0.5)?;
let dev = points.device();
let (points, labels) = if pad {
let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
let points = Tensor::cat(&[&points, &padding_point], 1)?;
let labels = Tensor::cat(&[labels, &padding_label], 1)?;
(points, labels)
} else {
(points, labels.clone())
};
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
let point_embedding = labels.lt(0f32)?.where_cond(
&self
.not_a_point_embed
.embeddings()
.broadcast_as(zeros.shape())?,
&point_embedding,
)?;
let labels0 = labels.eq(0f32)?.where_cond(
&self.point_embeddings[0]
.embeddings()
.broadcast_as(zeros.shape())?,
&zeros,
)?;
let point_embedding = (point_embedding + labels0)?;
let labels1 = labels.eq(1f32)?.where_cond(
&self.point_embeddings[1]
.embeddings()
.broadcast_as(zeros.shape())?,
&zeros,
)?;
let point_embedding = (point_embedding + labels1)?;
Ok(point_embedding)
}
fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
let boxes = (boxes + 0.5)?;
let coords = boxes.reshape(((), 2, 2))?;
let corner_embedding = self
.pe_layer
.forward_with_coords(&coords, self.input_image_size)?;
let ce1 = corner_embedding.i((.., 0))?;
let ce2 = corner_embedding.i((.., 1))?;
let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
Tensor::cat(&[&ce1, &ce2], 1)
}
pub fn forward(
&self,
points: Option<(&Tensor, &Tensor)>,
boxes: Option<&Tensor>,
masks: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();
let se_points = match points {
Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
None => None,
};
let se_boxes = match boxes {
Some(boxes) => Some(self.embed_boxes(boxes)?),
None => None,
};
let sparse_embeddings = match (se_points, se_boxes) {
(Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
(Some(se_points), None) => se_points,
(None, Some(se_boxes)) => se_boxes,
(None, None) => {
Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
}
};
let dense_embeddings = match masks {
None => {
let emb = self.no_mask_embed.embeddings();
emb.reshape((1, (), 1, 1))?.expand((
1,
emb.elem_count(),
self.image_embedding_size.0,
self.image_embedding_size.1,
))?
}
Some(masks) => self.embed_masks(masks)?,
};
Ok((sparse_embeddings, dense_embeddings))
}
}

View File

@ -1,411 +0,0 @@
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use crate::model_image_encoder::ImageEncoderViT;
use crate::model_mask_decoder::MaskDecoder;
use crate::model_prompt_encoder::PromptEncoder;
use crate::model_tiny_vit::{tiny_vit_5m, TinyViT};
const PROMPT_EMBED_DIM: usize = 256;
pub const IMAGE_SIZE: usize = 1024;
const VIT_PATCH_SIZE: usize = 16;
const PRED_IOU_THRESH: f32 = 0.88;
const STABILITY_SCORE_OFFSET: f32 = 1.0;
const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
const MODEL_MASK_THRESHOLD: f32 = 0.0;
const CROP_NMS_THRESH: f32 = 0.7;
#[derive(Debug)]
enum ImageEncoder {
Original(ImageEncoderViT),
TinyViT(TinyViT),
}
impl Module for ImageEncoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Original(vit) => vit.forward(xs),
Self::TinyViT(vit) => vit.forward(xs),
}
}
}
#[derive(Debug)]
pub struct Sam {
image_encoder: ImageEncoder,
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: Tensor,
pixel_std: Tensor,
}
impl Sam {
pub fn new(
encoder_embed_dim: usize,
encoder_depth: usize,
encoder_num_heads: usize,
encoder_global_attn_indexes: &[usize],
vb: VarBuilder,
) -> Result<Self> {
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
let image_encoder = ImageEncoderViT::new(
IMAGE_SIZE,
VIT_PATCH_SIZE,
3,
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
PROMPT_EMBED_DIM,
/* qkv_bias */ true,
/* use_rel_pos */ true,
/* use_abs_pos */ true,
/* window_size */ 14,
/* global_attn_indexes */ encoder_global_attn_indexes,
vb.pp("image_encoder"),
)?;
let prompt_encoder = PromptEncoder::new(
PROMPT_EMBED_DIM,
(image_embedding_size, image_embedding_size),
(IMAGE_SIZE, IMAGE_SIZE),
16,
vb.pp("prompt_encoder"),
)?;
let mask_decoder = MaskDecoder::new(
PROMPT_EMBED_DIM,
/* num_multitask_outputs */ 3,
/* iou_head_depth */ 3,
/* iou_head_hidden_dim */ 256,
vb.pp("mask_decoder"),
)?;
let pixel_mean =
Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
let pixel_std =
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
Ok(Self {
image_encoder: ImageEncoder::Original(image_encoder),
prompt_encoder,
mask_decoder,
pixel_std,
pixel_mean,
})
}
pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
let prompt_encoder = PromptEncoder::new(
PROMPT_EMBED_DIM,
(image_embedding_size, image_embedding_size),
(IMAGE_SIZE, IMAGE_SIZE),
16,
vb.pp("prompt_encoder"),
)?;
let mask_decoder = MaskDecoder::new(
PROMPT_EMBED_DIM,
/* num_multitask_outputs */ 3,
/* iou_head_depth */ 3,
/* iou_head_hidden_dim */ 256,
vb.pp("mask_decoder"),
)?;
let pixel_mean =
Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
let pixel_std =
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
Ok(Self {
image_encoder: ImageEncoder::TinyViT(image_encoder),
prompt_encoder,
mask_decoder,
pixel_std,
pixel_mean,
})
}
pub fn forward(
&self,
img: &Tensor,
point: Option<(f64, f64)>,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let (_c, original_h, original_w) = img.dims3()?;
let img = self.preprocess(img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = match point {
None => None,
Some((x, y)) => {
let points = Tensor::new(
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
img.device(),
)?;
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
Some((points, labels))
}
};
let points = points.as_ref().map(|(x, y)| (x, y));
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder.forward(points, None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
multimask_output,
)?;
let mask = low_res_mask
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
.get(0)?
.i((.., ..original_h, ..original_w))?;
Ok((mask, iou_predictions))
}
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
let img = img
.broadcast_mul(&self.pixel_std)?
.broadcast_add(&self.pixel_mean)?;
img.maximum(&img.zeros_like()?)?
.minimum(&(img.ones_like()? * 255.)?)
}
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
let (_c, h, w) = img.dims3()?;
let img = img
.to_dtype(DType::F32)?
.broadcast_sub(&self.pixel_mean)?
.broadcast_div(&self.pixel_std)?;
if h > IMAGE_SIZE || w > IMAGE_SIZE {
candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
}
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
}
fn process_crop(
&self,
img: &Tensor,
cb: CropBox,
point_grids: &[(f64, f64)],
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
// Crop the image and calculate embeddings.
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
let img = self.preprocess(&img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
let crop_w = cb.x1 - cb.x0;
let crop_h = cb.y1 - cb.y0;
// Generate masks for this crop.
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = point_grids
.iter()
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
.collect::<Vec<_>>();
let mut bboxes = Vec::new();
for points in points.chunks(64) {
// Run the model on this batch.
let points_len = points.len();
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder
.forward(Some((&in_points, &in_labels)), None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
/* multimask_output */ true,
)?;
let low_res_mask = low_res_mask.flatten(0, 1)?;
let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
let dev = low_res_mask.device();
for (i, iou) in iou_predictions.iter().enumerate() {
// Filter by predicted IoU.
if *iou < PRED_IOU_THRESH {
continue;
}
let low_res_mask = low_res_mask.get(i)?;
// Calculate stability score.
let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
.broadcast_as(low_res_mask.shape())?;
let intersections = low_res_mask
.ge(&bound)?
.to_dtype(DType::F32)?
.sum_all()?
.to_vec0::<f32>()?;
let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
.broadcast_as(low_res_mask.shape())?;
let unions = low_res_mask
.ge(&bound)?
.to_dtype(DType::F32)?
.sum_all()?
.to_vec0::<f32>()?;
let stability_score = intersections / unions;
if stability_score < STABILITY_SCORE_THRESHOLD {
continue;
}
// Threshold masks and calculate boxes.
let low_res_mask = low_res_mask
.ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
.to_dtype(DType::U32)?;
let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
let min_max_x = min_max_indexes(&low_res_mask_per_x);
let min_max_y = min_max_indexes(&low_res_mask_per_y);
if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
let bbox = candle_examples::object_detection::Bbox {
xmin: x0 as f32,
ymin: y0 as f32,
xmax: x1 as f32,
ymax: y1 as f32,
confidence: *iou,
data: low_res_mask,
};
bboxes.push(bbox);
}
// TODO:
// Filter boxes that touch crop boundaries
// Compress to RLE.
}
}
let mut bboxes = vec![bboxes];
// Remove duplicates within this crop.
candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
// TODO: Return to the original image frame.
Ok(bboxes.remove(0))
}
pub fn generate_masks(
&self,
img: &Tensor,
points_per_side: usize,
crop_n_layer: usize,
crop_overlap_ratio: f64,
crop_n_points_downscale_factor: usize,
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
let (_c, h, w) = img.dims3()?;
let point_grids = build_all_layer_point_grids(
points_per_side,
crop_n_layer,
crop_n_points_downscale_factor,
);
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
let mut bboxes = Vec::new();
for crop_box in crop_boxes.into_iter() {
let layer_idx = crop_box.layer_idx;
let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
bboxes.extend(b)
}
// TODO: remove duplicates
Ok(bboxes)
}
}
// Return the first and last indexes i for which values[i] > 0
fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
for (i, &s) in values.iter().enumerate() {
if s == 0 {
continue;
}
min_i = usize::min(i, min_i);
max_i = usize::max(i, max_i);
}
if max_i < min_i {
None
} else {
Some((min_i, max_i))
}
}
#[derive(Debug)]
struct CropBox {
x0: usize,
y0: usize,
x1: usize,
y1: usize,
layer_idx: usize,
}
impl CropBox {
fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
Self {
x0,
y0,
x1,
y1,
layer_idx,
}
}
}
fn generate_crop_boxes(
(im_h, im_w): (usize, usize),
n_layers: usize,
overlap_ratio: f64,
) -> Vec<CropBox> {
fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
}
let short_side = usize::min(im_h, im_w);
let mut crop_boxes = Vec::new();
// Original image.
crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
for layer_idx in 1..=n_layers {
let n_crops_per_side = 1 << layer_idx;
let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
let crop_w = crop_len(im_w, n_crops_per_side, overlap);
let crop_h = crop_len(im_w, n_crops_per_side, overlap);
for i_x in 0..n_crops_per_side {
let x0 = (crop_w - overlap) * i_x;
for i_y in 0..n_crops_per_side {
let y0 = (crop_h - overlap) * i_y;
let x1 = usize::min(im_w, x0 + crop_w);
let y1 = usize::min(im_h, y0 + crop_h);
crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
}
}
}
crop_boxes
}
// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
let offset = 1f64 / (2 * n_per_side) as f64;
let mut points = Vec::with_capacity(n_per_side * n_per_side);
for i_x in 0..n_per_side {
let x = offset + i_x as f64 / n_per_side as f64;
for i_y in 0..n_per_side {
let y = offset + i_y as f64 / n_per_side as f64;
points.push((x, y))
}
}
points
}
fn build_all_layer_point_grids(
n_per_side: usize,
n_layers: usize,
scale_per_layer: usize,
) -> Vec<Vec<(f64, f64)>> {
let mut points_by_layer = Vec::with_capacity(n_layers + 1);
for i in 0..=n_layers {
let n_points = n_per_side / scale_per_layer.pow(i as u32);
points_by_layer.push(build_point_grid(n_points))
}
points_by_layer
}

View File

@ -1,633 +0,0 @@
// Adapted from:
// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{Conv2dConfig, Module, VarBuilder};
const MBCONV_EXPAND_RATIO: usize = 4;
const MLP_RATIO: usize = 4;
const LOCAL_CONV_SIZE: usize = 3;
const IMG_SIZE: usize = 1024;
const IN_CHANNELS: usize = 3;
#[derive(Debug)]
struct Conv2dBN {
c: candle_nn::Conv2d,
bn: candle_nn::BatchNorm,
span: tracing::Span,
}
impl Conv2dBN {
fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
Ok(Self { c, bn, span })
}
}
impl Module for Conv2dBN {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.c)?.apply(&self.bn)
}
}
#[derive(Debug)]
struct PatchEmbed {
conv1: Conv2dBN,
conv2: Conv2dBN,
span: tracing::Span,
}
impl PatchEmbed {
fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
let cfg = candle_nn::Conv2dConfig {
stride: 2,
padding: 1,
..Default::default()
};
let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
Ok(Self { conv1, conv2, span })
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
}
}
#[derive(Debug)]
struct MBConv {
conv1: Conv2dBN,
conv2: Conv2dBN,
conv3: Conv2dBN,
span: tracing::Span,
}
impl MBConv {
fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
let hidden = in_ * expand_ratio;
let cfg2 = candle_nn::Conv2dConfig {
padding: 1,
groups: hidden,
..Default::default()
};
let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
Ok(Self {
conv1,
conv2,
conv3,
span,
})
}
}
impl Module for MBConv {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let shortcut = xs;
let xs = xs
.apply(&self.conv1)?
.gelu()?
.apply(&self.conv2)?
.gelu()?
.apply(&self.conv3)?;
(xs + shortcut)?.gelu()
}
}
#[derive(Debug)]
struct PatchMerging {
conv1: Conv2dBN,
conv2: Conv2dBN,
conv3: Conv2dBN,
input_resolution: (usize, usize),
span: tracing::Span,
}
impl PatchMerging {
fn new(
input_resolution: (usize, usize),
dim: usize,
out: usize,
vb: VarBuilder,
) -> Result<Self> {
let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
let cfg2 = candle_nn::Conv2dConfig {
padding: 1,
stride,
groups: out,
..Default::default()
};
let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
Ok(Self {
conv1,
conv2,
conv3,
input_resolution,
span,
})
}
}
impl Module for PatchMerging {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = if xs.rank() == 3 {
let (h, w) = self.input_resolution;
let b = xs.dim(0)?;
xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
} else {
xs.clone()
};
xs.apply(&self.conv1)?
.gelu()?
.apply(&self.conv2)?
.gelu()?
.apply(&self.conv3)?
.flatten_from(2)?
.transpose(1, 2)
}
}
#[derive(Debug)]
struct ConvLayer {
blocks: Vec<MBConv>,
downsample: Option<PatchMerging>,
span: tracing::Span,
}
impl ConvLayer {
fn new(
dim: usize,
out: usize,
input_resolution: (usize, usize),
depth: usize,
downsample: bool,
conv_expand_ratio: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb_b = vb.pp("blocks");
let mut blocks = Vec::with_capacity(depth);
for index in 0..depth {
let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?;
blocks.push(block)
}
let downsample = if downsample {
let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
Some(downsample)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
Ok(Self {
blocks,
downsample,
span,
})
}
}
impl Module for ConvLayer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
match &self.downsample {
None => Ok(xs),
Some(downsample) => downsample.forward(&xs),
}
}
}
#[derive(Debug)]
struct Mlp {
norm: candle_nn::LayerNorm,
fc1: crate::Linear,
fc2: crate::Linear,
span: tracing::Span,
}
impl Mlp {
fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp");
Ok(Self {
norm,
fc1,
fc2,
span,
})
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.norm)?
.apply(&self.fc1)?
.gelu()?
.apply(&self.fc2)
}
}
#[derive(Debug)]
struct Attention {
norm: candle_nn::LayerNorm,
qkv: crate::Linear,
proj: crate::Linear,
ab: Tensor,
key_dim: usize,
num_heads: usize,
d: usize,
dh: usize,
scale: f64,
span: tracing::Span,
span_matmul: tracing::Span,
span_softmax: tracing::Span,
}
impl Attention {
fn new(
dim: usize,
key_dim: usize,
num_heads: usize,
attn_ratio: usize,
resolution: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let d = attn_ratio * key_dim;
let dh = d * num_heads;
let nh_kd = key_dim * num_heads;
let h = dh + nh_kd * 2;
let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
let points = (0..resolution.0)
.flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
.collect::<Vec<_>>();
let mut idxs = Vec::with_capacity(points.len() * points.len());
let mut attention_offsets = std::collections::HashMap::new();
for &(x1, y1) in points.iter() {
for &(x2, y2) in points.iter() {
let offset = ((x2 - x1).abs(), (y2 - y1).abs());
let l = attention_offsets.len();
let idx = attention_offsets.entry(offset).or_insert(l);
idxs.push(*idx as u32)
}
}
let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
let idxs = Tensor::new(idxs, attention_biases.device())?;
let ab =
attention_biases
.index_select(&idxs, 1)?
.reshape(((), points.len(), points.len()))?;
let span = tracing::span!(tracing::Level::TRACE, "attention");
let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
Ok(Self {
norm,
qkv,
proj,
ab,
key_dim,
num_heads,
d,
dh,
scale: 1f64 / (key_dim as f64).sqrt(),
span,
span_matmul,
span_softmax,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b, n, _) = xs.dims3()?;
let xs = xs.apply(&self.norm)?;
let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
let q = qkv
.narrow(D::Minus1, 0, self.key_dim)?
.permute((0, 2, 1, 3))?
.contiguous()?;
let k = qkv
.narrow(D::Minus1, self.key_dim, self.key_dim)?
.permute((0, 2, 1, 3))?
.contiguous()?;
let v = qkv
.narrow(D::Minus1, 2 * self.key_dim, self.d)?
.permute((0, 2, 1, 3))?
.contiguous()?;
let attn = {
let _enter = self.span_matmul.enter();
(q.matmul(&k.t()?)? * self.scale)?
};
let attn = attn.broadcast_add(&self.ab)?;
let attn = {
let _enter = self.span_softmax.enter();
candle_nn::ops::softmax_last_dim(&attn)?
};
let attn = {
let _enter = self.span_matmul.enter();
attn.matmul(&v)?
};
attn.transpose(1, 2)?
.reshape((b, n, self.dh))?
.apply(&self.proj)
}
}
#[derive(Debug)]
struct TinyViTBlock {
attn: Attention,
local_conv: Conv2dBN,
mlp: Mlp,
window_size: usize,
input_resolution: (usize, usize),
span: tracing::Span,
}
impl TinyViTBlock {
fn new(
dim: usize,
input_resolution: (usize, usize),
num_heads: usize,
window_size: usize,
vb: VarBuilder,
) -> Result<Self> {
let head_dim = dim / num_heads;
let attn = Attention::new(
dim,
head_dim,
num_heads,
1,
(window_size, window_size),
vb.pp("attn"),
)?;
let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
let cfg = candle_nn::Conv2dConfig {
padding: LOCAL_CONV_SIZE / 2,
groups: dim,
..Default::default()
};
let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
let span = tracing::span!(tracing::Level::TRACE, "attention");
Ok(Self {
attn,
local_conv,
mlp,
window_size,
input_resolution,
span,
})
}
}
impl Module for TinyViTBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (h, w) = self.input_resolution;
let (b, l, c) = xs.dims3()?;
let res_x = xs;
let xs = if h == self.window_size && w == self.window_size {
self.attn.forward(xs)?
} else {
let xs = xs.reshape((b, h, w, c))?;
let pad_b = (self.window_size - h % self.window_size) % self.window_size;
let pad_r = (self.window_size - w % self.window_size) % self.window_size;
let xs = if pad_b > 0 {
xs.pad_with_zeros(1, 0, pad_b)?
} else {
xs
};
let xs = if pad_r > 0 {
xs.pad_with_zeros(2, 0, pad_r)?
} else {
xs
};
let (p_h, p_w) = (h + pad_b, w + pad_r);
let n_h = p_h / self.window_size;
let n_w = p_w / self.window_size;
let xs = xs
.reshape((b, n_h, self.window_size, n_w, self.window_size, c))?
.transpose(2, 3)?
.reshape((b * n_h * n_w, self.window_size * self.window_size, c))?;
let xs = self.attn.forward(&xs)?;
let xs = xs
.reshape((b, n_h, n_w, self.window_size, self.window_size, c))?
.transpose(2, 3)?
.reshape((b, p_h, p_w, c))?;
let xs = if pad_r > 0 {
xs.i((.., .., ..w))?.contiguous()?
} else {
xs
};
let xs = if pad_b > 0 {
xs.i((.., ..h, ..))?.contiguous()?
} else {
xs
};
xs.reshape((b, l, c))?
};
let xs = (xs + res_x)?;
let xs = xs
.transpose(1, 2)?
.reshape((b, c, h, w))?
.apply(&self.local_conv)?
.reshape((b, c, l))?
.transpose(1, 2)?;
&xs + self.mlp.forward(&xs)?
}
}
#[derive(Debug)]
struct BasicLayer {
blocks: Vec<TinyViTBlock>,
downsample: Option<PatchMerging>,
span: tracing::Span,
}
impl BasicLayer {
#[allow(clippy::too_many_arguments)]
fn new(
dim: usize,
input_resolution: (usize, usize),
depth: usize,
num_heads: usize,
window_size: usize,
downsample: bool,
out: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb_b = vb.pp("blocks");
let mut blocks = Vec::with_capacity(depth);
for index in 0..depth {
let block = TinyViTBlock::new(
dim,
input_resolution,
num_heads,
window_size,
vb_b.pp(index),
)?;
blocks.push(block)
}
let downsample = if downsample {
let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
Some(downsample)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
Ok(Self {
blocks,
downsample,
span,
})
}
}
impl Module for BasicLayer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
match &self.downsample {
None => Ok(xs),
Some(downsample) => downsample.forward(&xs),
}
}
}
#[derive(Debug)]
pub struct TinyViT {
patch_embed: PatchEmbed,
layer0: ConvLayer,
layers: Vec<BasicLayer>,
// norm_head: candle_nn::LayerNorm,
// head: candle_nn::Linear,
neck_conv1: candle_nn::Conv2d,
neck_ln1: crate::LayerNorm2d,
neck_conv2: candle_nn::Conv2d,
neck_ln2: crate::LayerNorm2d,
span: tracing::Span,
span_neck: tracing::Span,
}
impl TinyViT {
pub fn new(
embed_dims: &[usize],
depths: &[usize],
num_heads: &[usize],
window_sizes: &[usize],
_num_classes: usize,
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
let patches_resolution = IMG_SIZE / 4;
let vb_l = vb.pp("layers");
let layer0 = ConvLayer::new(
/* dim */ embed_dims[0],
/* out */ embed_dims[1],
/* input_resolution */ (patches_resolution, patches_resolution),
/* depth */ depths[0],
/* downsample */ true,
/* conv_expand_ratio */ MBCONV_EXPAND_RATIO,
vb_l.pp(0),
)?;
let num_layers = embed_dims.len();
let mut layers = Vec::with_capacity(num_layers - 1);
for i_layer in 1..num_layers {
let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
let layer = BasicLayer::new(
/* dim */ embed_dims[i_layer],
/* input_resolution */ (patches_resolution, patches_resolution),
/* depth */ depths[i_layer],
/* num_heads */ num_heads[i_layer],
/* window_size */ window_sizes[i_layer],
/* downsample */ i_layer < num_layers - 1,
/* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)],
vb_l.pp(i_layer),
)?;
layers.push(layer)
}
let last_embed_dim = embed_dims[embed_dims.len() - 1];
// let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
// let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
let neck_conv1 =
candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
Ok(Self {
patch_embed,
layer0,
layers,
neck_conv1,
neck_ln1,
neck_conv2,
neck_ln2,
span,
span_neck,
})
}
}
impl Module for TinyViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.patch_embed.forward(xs)?;
let mut xs = self.layer0.forward(&xs)?;
for layer in self.layers.iter() {
xs = layer.forward(&xs)?
}
let (b, _, c) = xs.dims3()?;
let _enter = self.span_neck.enter();
xs.reshape((b, 64, 64, c))?
.permute((0, 3, 1, 2))?
.apply(&self.neck_conv1)?
.apply(&self.neck_ln1)?
.apply(&self.neck_conv2)?
.apply(&self.neck_ln2)
}
}
pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
TinyViT::new(
/* embed_dims */ &[64, 128, 160, 320],
/* depths */ &[2, 2, 6, 2],
/* num_heads */ &[2, 4, 5, 10],
/* window_sizes */ &[7, 7, 14, 7],
/* num_classes */ 1000,
vb,
)
}

View File

@ -1,221 +0,0 @@
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
num_heads: usize,
}
impl Attention {
fn new(
embedding_dim: usize,
num_heads: usize,
downsample_rate: usize,
vb: VarBuilder,
) -> Result<Self> {
let internal_dim = embedding_dim / downsample_rate;
let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
num_heads,
})
}
fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n, c) = x.dims3()?;
x.reshape((b, n, self.num_heads, c / self.num_heads))?
.transpose(1, 2)?
.contiguous()
}
fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
x.transpose(1, 2)?
.reshape((b, n_tokens, n_heads * c_per_head))
}
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let q = self.q_proj.forward(&q.contiguous()?)?;
let k = self.k_proj.forward(&k.contiguous()?)?;
let v = self.v_proj.forward(&v.contiguous()?)?;
let q = self.separate_heads(&q)?;
let k = self.separate_heads(&k)?;
let v = self.separate_heads(&v)?;
let (_, _, _, c_per_head) = q.dims4()?;
let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
let out = attn.matmul(&v)?;
self.recombine_heads(&out)?.apply(&self.out_proj)
}
}
#[derive(Debug)]
struct TwoWayAttentionBlock {
self_attn: Attention,
norm1: LayerNorm,
cross_attn_token_to_image: Attention,
norm2: LayerNorm,
mlp: crate::MlpBlock,
norm3: LayerNorm,
norm4: LayerNorm,
cross_attn_image_to_token: Attention,
skip_first_layer_pe: bool,
}
impl TwoWayAttentionBlock {
fn new(
embedding_dim: usize,
num_heads: usize,
mlp_dim: usize,
skip_first_layer_pe: bool,
vb: VarBuilder,
) -> Result<Self> {
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
let cross_attn_token_to_image = Attention::new(
embedding_dim,
num_heads,
2,
vb.pp("cross_attn_token_to_image"),
)?;
let cross_attn_image_to_token = Attention::new(
embedding_dim,
num_heads,
2,
vb.pp("cross_attn_image_to_token"),
)?;
let mlp = crate::MlpBlock::new(
embedding_dim,
mlp_dim,
candle_nn::Activation::Relu,
vb.pp("mlp"),
)?;
Ok(Self {
self_attn,
norm1,
cross_attn_image_to_token,
norm2,
mlp,
norm3,
norm4,
cross_attn_token_to_image,
skip_first_layer_pe,
})
}
fn forward(
&self,
queries: &Tensor,
keys: &Tensor,
query_pe: &Tensor,
key_pe: &Tensor,
) -> Result<(Tensor, Tensor)> {
// Self attention block
let queries = if self.skip_first_layer_pe {
self.self_attn.forward(queries, queries, queries)?
} else {
let q = (queries + query_pe)?;
let attn_out = self.self_attn.forward(&q, &q, queries)?;
(queries + attn_out)?
};
let queries = self.norm1.forward(&queries)?;
// Cross attention block, tokens attending to image embedding
let q = (&queries + query_pe)?;
let k = (keys + key_pe)?;
let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
let queries = (&queries + attn_out)?;
let queries = self.norm2.forward(&queries)?;
// MLP block
let mlp_out = self.mlp.forward(&queries);
let queries = (queries + mlp_out)?;
let queries = self.norm3.forward(&queries)?;
// Cross attention block, image embedding attending to tokens
let q = (&queries + query_pe)?;
let k = (keys + key_pe)?;
let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
let keys = (keys + attn_out)?;
let keys = self.norm4.forward(&keys)?;
Ok((queries, keys))
}
}
#[derive(Debug)]
pub struct TwoWayTransformer {
layers: Vec<TwoWayAttentionBlock>,
final_attn_token_to_image: Attention,
norm_final_attn: LayerNorm,
}
impl TwoWayTransformer {
pub fn new(
depth: usize,
embedding_dim: usize,
num_heads: usize,
mlp_dim: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb_l = vb.pp("layers");
let mut layers = Vec::with_capacity(depth);
for i in 0..depth {
let layer =
TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
layers.push(layer)
}
let final_attn_token_to_image = Attention::new(
embedding_dim,
num_heads,
2,
vb.pp("final_attn_token_to_image"),
)?;
let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
Ok(Self {
layers,
final_attn_token_to_image,
norm_final_attn,
})
}
pub fn forward(
&self,
image_embedding: &Tensor,
image_pe: &Tensor,
point_embedding: &Tensor,
) -> Result<(Tensor, Tensor)> {
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
let mut queries = point_embedding.clone();
let mut keys = image_embedding;
for layer in self.layers.iter() {
(queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
}
let q = (&queries + point_embedding)?;
let k = (&keys + image_pe)?;
let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
Ok((queries, keys))
}
}

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_examples::object_detection::{non_maximum_suppression, Bbox};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox};
mod darknet;
use anyhow::Result;

View File

@ -8,8 +8,8 @@ mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
use image::DynamicImage;