Get the MobileSAM TinyViT based version to work. (#789)

* More TinyViT support in SA.

* More mobilesam work.

* Add the mobile-sam weights to the hub.
This commit is contained in:
Laurent Mazare
2023-09-09 16:21:44 +01:00
committed by GitHub
parent b7cd58473b
commit 74ad4deb42
3 changed files with 89 additions and 26 deletions

View File

@ -133,6 +133,10 @@ struct Args {
/// Enable tracing (generates a trace-timestamp.json file). /// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)] #[arg(long)]
tracing: bool, tracing: bool,
/// Use the TinyViT based models from MobileSAM
#[arg(long)]
use_tiny: bool,
} }
pub fn main() -> anyhow::Result<()> { pub fn main() -> anyhow::Result<()> {
@ -179,13 +183,22 @@ pub fn main() -> anyhow::Result<()> {
None => { None => {
let api = hf_hub::api::sync::Api::new()?; let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-sam".to_string()); let api = api.model("lmz/candle-sam".to_string());
api.get("sam_vit_b_01ec64.safetensors")? let filename = if args.use_tiny {
"mobile_sam-tiny-vitt.safetensors"
} else {
"sam_vit_b_01ec64.safetensors"
};
api.get(filename)?
} }
}; };
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?; let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b let sam = if args.use_tiny {
model_sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks { if args.generate_masks {
// Default options similar to the Python version. // Default options similar to the Python version.

View File

@ -4,6 +4,7 @@ use candle_nn::{Module, VarBuilder};
use crate::model_image_encoder::ImageEncoderViT; use crate::model_image_encoder::ImageEncoderViT;
use crate::model_mask_decoder::MaskDecoder; use crate::model_mask_decoder::MaskDecoder;
use crate::model_prompt_encoder::PromptEncoder; use crate::model_prompt_encoder::PromptEncoder;
use crate::model_tiny_vit::{tiny_vit_5m, TinyViT};
const PROMPT_EMBED_DIM: usize = 256; const PROMPT_EMBED_DIM: usize = 256;
pub const IMAGE_SIZE: usize = 1024; pub const IMAGE_SIZE: usize = 1024;
@ -14,9 +15,24 @@ const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
const MODEL_MASK_THRESHOLD: f32 = 0.0; const MODEL_MASK_THRESHOLD: f32 = 0.0;
const CROP_NMS_THRESH: f32 = 0.7; const CROP_NMS_THRESH: f32 = 0.7;
#[derive(Debug)]
enum ImageEncoder {
Original(ImageEncoderViT),
TinyViT(TinyViT),
}
impl Module for ImageEncoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Original(vit) => vit.forward(xs),
Self::TinyViT(vit) => vit.forward(xs),
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct Sam { pub struct Sam {
image_encoder: ImageEncoderViT, image_encoder: ImageEncoder,
prompt_encoder: PromptEncoder, prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder, mask_decoder: MaskDecoder,
pixel_mean: Tensor, pixel_mean: Tensor,
@ -67,7 +83,38 @@ impl Sam {
let pixel_std = let pixel_std =
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
Ok(Self { Ok(Self {
image_encoder, image_encoder: ImageEncoder::Original(image_encoder),
prompt_encoder,
mask_decoder,
pixel_std,
pixel_mean,
})
}
pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
let prompt_encoder = PromptEncoder::new(
PROMPT_EMBED_DIM,
(image_embedding_size, image_embedding_size),
(IMAGE_SIZE, IMAGE_SIZE),
16,
vb.pp("prompt_encoder"),
)?;
let mask_decoder = MaskDecoder::new(
PROMPT_EMBED_DIM,
/* num_multitask_outputs */ 3,
/* iou_head_depth */ 3,
/* iou_head_hidden_dim */ 256,
vb.pp("mask_decoder"),
)?;
let pixel_mean =
Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
let pixel_std =
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
Ok(Self {
image_encoder: ImageEncoder::TinyViT(image_encoder),
prompt_encoder, prompt_encoder,
mask_decoder, mask_decoder,
pixel_std, pixel_std,

View File

@ -1,13 +1,12 @@
// Adapted from: // Adapted from:
// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py // https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
#![allow(unused)] use candle::{IndexOp, Result, Tensor, D};
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2dConfig, Module, VarBuilder}; use candle_nn::{Conv2dConfig, Module, VarBuilder};
const MBCONV_EXPAND_RATIO: usize = 4; const MBCONV_EXPAND_RATIO: usize = 4;
const MLP_RATIO: usize = 4; const MLP_RATIO: usize = 4;
const LOCAL_CONV_SIZE: usize = 3; const LOCAL_CONV_SIZE: usize = 3;
const IMG_SIZE: usize = 224; const IMG_SIZE: usize = 1024;
const IN_CHANNELS: usize = 3; const IN_CHANNELS: usize = 3;
#[derive(Debug)] #[derive(Debug)]
@ -18,7 +17,7 @@ struct Conv2dBN {
impl Conv2dBN { impl Conv2dBN {
fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> { fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
let c = candle_nn::conv2d(in_, out, ks, cfg, vb.pp("c"))?; let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?; let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
Ok(Self { c, bn }) Ok(Self { c, bn })
} }
@ -222,7 +221,6 @@ struct Attention {
norm: candle_nn::LayerNorm, norm: candle_nn::LayerNorm,
qkv: candle_nn::Linear, qkv: candle_nn::Linear,
proj: candle_nn::Linear, proj: candle_nn::Linear,
attention_biases: Tensor,
ab: Tensor, ab: Tensor,
key_dim: usize, key_dim: usize,
num_heads: usize, num_heads: usize,
@ -263,12 +261,14 @@ impl Attention {
} }
let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?; let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
let idxs = Tensor::new(idxs, attention_biases.device())?; let idxs = Tensor::new(idxs, attention_biases.device())?;
let ab = attention_biases.index_select(&idxs, 1)?; let ab =
attention_biases
.index_select(&idxs, 1)?
.reshape(((), points.len(), points.len()))?;
Ok(Self { Ok(Self {
norm, norm,
qkv, qkv,
proj, proj,
attention_biases,
ab, ab,
key_dim, key_dim,
num_heads, num_heads,
@ -286,15 +286,18 @@ impl Module for Attention {
let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?; let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
let q = qkv let q = qkv
.narrow(D::Minus1, 0, self.key_dim)? .narrow(D::Minus1, 0, self.key_dim)?
.permute((0, 2, 1, 3))?; .permute((0, 2, 1, 3))?
.contiguous()?;
let k = qkv let k = qkv
.narrow(D::Minus1, self.key_dim, self.key_dim)? .narrow(D::Minus1, self.key_dim, self.key_dim)?
.permute((0, 2, 1, 3))?; .permute((0, 2, 1, 3))?
.contiguous()?;
let v = qkv let v = qkv
.narrow(D::Minus1, 2 * self.key_dim, self.d)? .narrow(D::Minus1, 2 * self.key_dim, self.d)?
.permute((0, 2, 1, 3))?; .permute((0, 2, 1, 3))?
.contiguous()?;
let attn = (q.matmul(&k.t()?)? * self.scale)?; let attn = (q.matmul(&k.t()?)? * self.scale)?;
let attn = (attn + &self.ab)?; let attn = attn.broadcast_add(&self.ab)?;
let attn = candle_nn::ops::softmax_last_dim(&attn)?; let attn = candle_nn::ops::softmax_last_dim(&attn)?;
attn.matmul(&v)? attn.matmul(&v)?
.transpose(1, 2)? .transpose(1, 2)?
@ -332,6 +335,7 @@ impl TinyViTBlock {
let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?; let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
let cfg = candle_nn::Conv2dConfig { let cfg = candle_nn::Conv2dConfig {
padding: LOCAL_CONV_SIZE / 2, padding: LOCAL_CONV_SIZE / 2,
groups: dim,
..Default::default() ..Default::default()
}; };
let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?; let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
@ -358,12 +362,12 @@ impl Module for TinyViTBlock {
let pad_r = (self.window_size - w % self.window_size) % self.window_size; let pad_r = (self.window_size - w % self.window_size) % self.window_size;
let xs = if pad_b > 0 { let xs = if pad_b > 0 {
xs.pad_with_zeros(D::Minus2, 0, pad_b)? xs.pad_with_zeros(1, 0, pad_b)?
} else { } else {
xs xs
}; };
let xs = if pad_r > 0 { let xs = if pad_r > 0 {
xs.pad_with_zeros(D::Minus1, 0, pad_r)? xs.pad_with_zeros(2, 0, pad_r)?
} else { } else {
xs xs
}; };
@ -460,8 +464,8 @@ pub struct TinyViT {
patch_embed: PatchEmbed, patch_embed: PatchEmbed,
layer0: ConvLayer, layer0: ConvLayer,
layers: Vec<BasicLayer>, layers: Vec<BasicLayer>,
norm_head: candle_nn::LayerNorm, // norm_head: candle_nn::LayerNorm,
head: candle_nn::Linear, // head: candle_nn::Linear,
neck_conv1: candle_nn::Conv2d, neck_conv1: candle_nn::Conv2d,
neck_ln1: crate::LayerNorm2d, neck_ln1: crate::LayerNorm2d,
neck_conv2: candle_nn::Conv2d, neck_conv2: candle_nn::Conv2d,
@ -474,7 +478,7 @@ impl TinyViT {
depths: &[usize], depths: &[usize],
num_heads: &[usize], num_heads: &[usize],
window_sizes: &[usize], window_sizes: &[usize],
num_classes: usize, _num_classes: usize,
vb: VarBuilder, vb: VarBuilder,
) -> Result<Self> { ) -> Result<Self> {
let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?; let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
@ -509,8 +513,8 @@ impl TinyViT {
} }
let last_embed_dim = embed_dims[embed_dims.len() - 1]; let last_embed_dim = embed_dims[embed_dims.len() - 1];
let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?; // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?; // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
let neck_conv1 = let neck_conv1 =
candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?; candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?; let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
@ -525,8 +529,6 @@ impl TinyViT {
patch_embed, patch_embed,
layer0, layer0,
layers, layers,
norm_head,
head,
neck_conv1, neck_conv1,
neck_ln1, neck_ln1,
neck_conv2, neck_conv2,
@ -537,7 +539,8 @@ impl TinyViT {
impl Module for TinyViT { impl Module for TinyViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.patch_embed.forward(xs)?; let xs = self.patch_embed.forward(xs)?;
let mut xs = self.layer0.forward(&xs)?;
for layer in self.layers.iter() { for layer in self.layers.iter() {
xs = layer.forward(&xs)? xs = layer.forward(&xs)?
} }
@ -551,7 +554,7 @@ impl Module for TinyViT {
} }
} }
pub fn tiny_vit_5m_224(vb: VarBuilder) -> Result<TinyViT> { pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
TinyViT::new( TinyViT::new(
/* embed_dims */ &[64, 128, 160, 320], /* embed_dims */ &[64, 128, 160, 320],
/* depths */ &[2, 2, 6, 2], /* depths */ &[2, 2, 6, 2],