Get the MobileSAM TinyViT-based version to work. (#789)

* More TinyViT support in SA.

* More mobilesam work.

* Add the mobile-sam weights to the hub.
Author: Laurent Mazare
Date: 2023-09-09 16:21:44 +01:00
Committed by: GitHub
Parent: b7cd58473b
Commit: 74ad4deb42
3 changed files with 89 additions and 26 deletions
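
For context, a minimal sketch of how the renamed `tiny_vit_5m` constructor from this diff might be driven once the MobileSAM weights are on the hub. The hub repo name (`lmz/candle-sam`), the safetensors file name, the `image_encoder` weight prefix, and the module path are assumptions for illustration only, not something this commit pins down:

```rust
// Sketch only: repo name, file name, weight prefix and module path are assumptions.
use candle::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::segment_anything::tiny_vit::tiny_vit_5m;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Fetch the MobileSAM checkpoint from the Hugging Face hub (names assumed).
    let weights = hf_hub::api::sync::Api::new()?
        .model("lmz/candle-sam".to_string())
        .get("mobile_sam-tiny-vitt.safetensors")?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)? };
    // Build the TinyViT image encoder; weights are assumed to live under "image_encoder".
    let encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
    // IMG_SIZE is now 1024, so feed a 1024x1024 RGB image (here just zeros).
    let img = Tensor::zeros((1, 3, 1024, 1024), DType::F32, &device)?;
    let embeddings = encoder.forward(&img)?;
    println!("image embeddings: {:?}", embeddings.dims());
    Ok(())
}
```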


@@ -1,13 +1,12 @@
// Adapted from:
// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
#![allow(unused)]
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{Conv2dConfig, Module, VarBuilder};
const MBCONV_EXPAND_RATIO: usize = 4;
const MLP_RATIO: usize = 4;
const LOCAL_CONV_SIZE: usize = 3;
-const IMG_SIZE: usize = 224;
+const IMG_SIZE: usize = 1024;
const IN_CHANNELS: usize = 3;
#[derive(Debug)]
@@ -18,7 +17,7 @@ struct Conv2dBN {
impl Conv2dBN {
fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
-let c = candle_nn::conv2d(in_, out, ks, cfg, vb.pp("c"))?;
+let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
Ok(Self { c, bn })
}
@@ -222,7 +221,6 @@ struct Attention {
norm: candle_nn::LayerNorm,
qkv: candle_nn::Linear,
proj: candle_nn::Linear,
-attention_biases: Tensor,
ab: Tensor,
key_dim: usize,
num_heads: usize,
@@ -263,12 +261,14 @@ impl Attention {
}
let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
let idxs = Tensor::new(idxs, attention_biases.device())?;
-let ab = attention_biases.index_select(&idxs, 1)?;
+let ab =
+    attention_biases
+        .index_select(&idxs, 1)?
+        .reshape(((), points.len(), points.len()))?;
Ok(Self {
norm,
qkv,
proj,
-attention_biases,
ab,
key_dim,
num_heads,
@@ -286,15 +286,18 @@ impl Module for Attention {
let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
let q = qkv
.narrow(D::Minus1, 0, self.key_dim)?
-.permute((0, 2, 1, 3))?;
+.permute((0, 2, 1, 3))?
+.contiguous()?;
let k = qkv
.narrow(D::Minus1, self.key_dim, self.key_dim)?
-.permute((0, 2, 1, 3))?;
+.permute((0, 2, 1, 3))?
+.contiguous()?;
let v = qkv
.narrow(D::Minus1, 2 * self.key_dim, self.d)?
-.permute((0, 2, 1, 3))?;
+.permute((0, 2, 1, 3))?
+.contiguous()?;
let attn = (q.matmul(&k.t()?)? * self.scale)?;
-let attn = (attn + &self.ab)?;
+let attn = attn.broadcast_add(&self.ab)?;
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
attn.matmul(&v)?
.transpose(1, 2)?
@@ -332,6 +335,7 @@ impl TinyViTBlock {
let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
let cfg = candle_nn::Conv2dConfig {
padding: LOCAL_CONV_SIZE / 2,
+groups: dim,
..Default::default()
};
let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
@@ -358,12 +362,12 @@ impl Module for TinyViTBlock {
let pad_r = (self.window_size - w % self.window_size) % self.window_size;
let xs = if pad_b > 0 {
-xs.pad_with_zeros(D::Minus2, 0, pad_b)?
+xs.pad_with_zeros(1, 0, pad_b)?
} else {
xs
};
let xs = if pad_r > 0 {
-xs.pad_with_zeros(D::Minus1, 0, pad_r)?
+xs.pad_with_zeros(2, 0, pad_r)?
} else {
xs
};
@@ -460,8 +464,8 @@ pub struct TinyViT {
patch_embed: PatchEmbed,
layer0: ConvLayer,
layers: Vec<BasicLayer>,
-norm_head: candle_nn::LayerNorm,
-head: candle_nn::Linear,
+// norm_head: candle_nn::LayerNorm,
+// head: candle_nn::Linear,
neck_conv1: candle_nn::Conv2d,
neck_ln1: crate::LayerNorm2d,
neck_conv2: candle_nn::Conv2d,
@@ -474,7 +478,7 @@ impl TinyViT {
depths: &[usize],
num_heads: &[usize],
window_sizes: &[usize],
-num_classes: usize,
+_num_classes: usize,
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
@@ -509,8 +513,8 @@ impl TinyViT {
}
let last_embed_dim = embed_dims[embed_dims.len() - 1];
-let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
-let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
+// let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
+// let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
let neck_conv1 =
candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
@@ -525,8 +529,6 @@ impl TinyViT {
patch_embed,
layer0,
layers,
-norm_head,
-head,
neck_conv1,
neck_ln1,
neck_conv2,
@@ -537,7 +539,8 @@ impl TinyViT {
impl Module for TinyViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-let mut xs = self.patch_embed.forward(xs)?;
+let xs = self.patch_embed.forward(xs)?;
+let mut xs = self.layer0.forward(&xs)?;
for layer in self.layers.iter() {
xs = layer.forward(&xs)?
}
@@ -551,7 +554,7 @@ impl Module for TinyViT {
}
}
-pub fn tiny_vit_5m_224(vb: VarBuilder) -> Result<TinyViT> {
+pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
TinyViT::new(
/* embed_dims */ &[64, 128, 160, 320],
/* depths */ &[2, 2, 6, 2],