From 74ad4deb42da71d4c47220e9595f58445f6f7298 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 9 Sep 2023 16:21:44 +0100 Subject: [PATCH] Get the MobileSAM TinyViT based version to work. (#789) * More TinyViT support in SA. * More mobilesam work. * Add the mobile-sam weights to the hub. --- .../examples/segment-anything/main.rs | 17 ++++++- .../examples/segment-anything/model_sam.rs | 51 ++++++++++++++++++- .../segment-anything/model_tiny_vit.rs | 47 +++++++++-------- 3 files changed, 89 insertions(+), 26 deletions(-) diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 3ef5762e..9ce2f158 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -133,6 +133,10 @@ struct Args { /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, + + /// Use the TinyViT based models from MobileSAM + #[arg(long)] + use_tiny: bool, } pub fn main() -> anyhow::Result<()> { @@ -179,13 +183,22 @@ pub fn main() -> anyhow::Result<()> { None => { let api = hf_hub::api::sync::Api::new()?; let api = api.model("lmz/candle-sam".to_string()); - api.get("sam_vit_b_01ec64.safetensors")? + let filename = if args.use_tiny { + "mobile_sam-tiny-vitt.safetensors" + } else { + "sam_vit_b_01ec64.safetensors" + }; + api.get(filename)? } }; let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); - let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b + let sam = if args.use_tiny { + model_sam::Sam::new_tiny(vb)? // tiny vit_t + } else { + model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b + }; if args.generate_masks { // Default options similar to the Python version. diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs index ade976c1..b1a81af6 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -4,6 +4,7 @@ use candle_nn::{Module, VarBuilder}; use crate::model_image_encoder::ImageEncoderViT; use crate::model_mask_decoder::MaskDecoder; use crate::model_prompt_encoder::PromptEncoder; +use crate::model_tiny_vit::{tiny_vit_5m, TinyViT}; const PROMPT_EMBED_DIM: usize = 256; pub const IMAGE_SIZE: usize = 1024; @@ -14,9 +15,24 @@ const STABILITY_SCORE_THRESHOLD: f32 = 0.95; const MODEL_MASK_THRESHOLD: f32 = 0.0; const CROP_NMS_THRESH: f32 = 0.7; +#[derive(Debug)] +enum ImageEncoder { + Original(ImageEncoderViT), + TinyViT(TinyViT), +} + +impl Module for ImageEncoder { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Original(vit) => vit.forward(xs), + Self::TinyViT(vit) => vit.forward(xs), + } + } +} + #[derive(Debug)] pub struct Sam { - image_encoder: ImageEncoderViT, + image_encoder: ImageEncoder, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder, pixel_mean: Tensor, @@ -67,7 +83,38 @@ impl Sam { let pixel_std = Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; Ok(Self { - image_encoder, + image_encoder: ImageEncoder::Original(image_encoder), + prompt_encoder, + mask_decoder, + pixel_std, + pixel_mean, + }) + } + + pub fn new_tiny(vb: VarBuilder) -> Result { + let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; + + let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?; + let prompt_encoder = PromptEncoder::new( + PROMPT_EMBED_DIM, + (image_embedding_size, image_embedding_size), + (IMAGE_SIZE, IMAGE_SIZE), + 16, + vb.pp("prompt_encoder"), + )?; + let mask_decoder = MaskDecoder::new( + PROMPT_EMBED_DIM, + /* num_multitask_outputs */ 3, + /* iou_head_depth */ 3, + /* iou_head_hidden_dim */ 256, + vb.pp("mask_decoder"), + )?; + let pixel_mean = + Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?; + let pixel_std = + Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; + Ok(Self { + image_encoder: ImageEncoder::TinyViT(image_encoder), prompt_encoder, mask_decoder, pixel_std, diff --git a/candle-examples/examples/segment-anything/model_tiny_vit.rs b/candle-examples/examples/segment-anything/model_tiny_vit.rs index 36e4c578..b3941ee1 100644 --- a/candle-examples/examples/segment-anything/model_tiny_vit.rs +++ b/candle-examples/examples/segment-anything/model_tiny_vit.rs @@ -1,13 +1,12 @@ // Adapted from: // https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py -#![allow(unused)] -use candle::{DType, IndexOp, Result, Tensor, D}; +use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{Conv2dConfig, Module, VarBuilder}; const MBCONV_EXPAND_RATIO: usize = 4; const MLP_RATIO: usize = 4; const LOCAL_CONV_SIZE: usize = 3; -const IMG_SIZE: usize = 224; +const IMG_SIZE: usize = 1024; const IN_CHANNELS: usize = 3; #[derive(Debug)] @@ -18,7 +17,7 @@ struct Conv2dBN { impl Conv2dBN { fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result { - let c = candle_nn::conv2d(in_, out, ks, cfg, vb.pp("c"))?; + let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?; let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?; Ok(Self { c, bn }) } @@ -222,7 +221,6 @@ struct Attention { norm: candle_nn::LayerNorm, qkv: candle_nn::Linear, proj: candle_nn::Linear, - attention_biases: Tensor, ab: Tensor, key_dim: usize, num_heads: usize, @@ -263,12 +261,14 @@ impl Attention { } let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?; let idxs = Tensor::new(idxs, attention_biases.device())?; - let ab = attention_biases.index_select(&idxs, 1)?; + let ab = + attention_biases + .index_select(&idxs, 1)? + .reshape(((), points.len(), points.len()))?; Ok(Self { norm, qkv, proj, - attention_biases, ab, key_dim, num_heads, @@ -286,15 +286,18 @@ impl Module for Attention { let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?; let q = qkv .narrow(D::Minus1, 0, self.key_dim)? - .permute((0, 2, 1, 3))?; + .permute((0, 2, 1, 3))? + .contiguous()?; let k = qkv .narrow(D::Minus1, self.key_dim, self.key_dim)? - .permute((0, 2, 1, 3))?; + .permute((0, 2, 1, 3))? + .contiguous()?; let v = qkv .narrow(D::Minus1, 2 * self.key_dim, self.d)? - .permute((0, 2, 1, 3))?; + .permute((0, 2, 1, 3))? + .contiguous()?; let attn = (q.matmul(&k.t()?)? * self.scale)?; - let attn = (attn + &self.ab)?; + let attn = attn.broadcast_add(&self.ab)?; let attn = candle_nn::ops::softmax_last_dim(&attn)?; attn.matmul(&v)? .transpose(1, 2)? @@ -332,6 +335,7 @@ impl TinyViTBlock { let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?; let cfg = candle_nn::Conv2dConfig { padding: LOCAL_CONV_SIZE / 2, + groups: dim, ..Default::default() }; let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?; @@ -358,12 +362,12 @@ impl Module for TinyViTBlock { let pad_r = (self.window_size - w % self.window_size) % self.window_size; let xs = if pad_b > 0 { - xs.pad_with_zeros(D::Minus2, 0, pad_b)? + xs.pad_with_zeros(1, 0, pad_b)? } else { xs }; let xs = if pad_r > 0 { - xs.pad_with_zeros(D::Minus1, 0, pad_r)? + xs.pad_with_zeros(2, 0, pad_r)? } else { xs }; @@ -460,8 +464,8 @@ pub struct TinyViT { patch_embed: PatchEmbed, layer0: ConvLayer, layers: Vec, - norm_head: candle_nn::LayerNorm, - head: candle_nn::Linear, + // norm_head: candle_nn::LayerNorm, + // head: candle_nn::Linear, neck_conv1: candle_nn::Conv2d, neck_ln1: crate::LayerNorm2d, neck_conv2: candle_nn::Conv2d, @@ -474,7 +478,7 @@ impl TinyViT { depths: &[usize], num_heads: &[usize], window_sizes: &[usize], - num_classes: usize, + _num_classes: usize, vb: VarBuilder, ) -> Result { let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?; @@ -509,8 +513,8 @@ impl TinyViT { } let last_embed_dim = embed_dims[embed_dims.len() - 1]; - let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?; - let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?; + // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?; + // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?; let neck_conv1 = candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?; let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?; @@ -525,8 +529,6 @@ impl TinyViT { patch_embed, layer0, layers, - norm_head, - head, neck_conv1, neck_ln1, neck_conv2, @@ -537,7 +539,8 @@ impl TinyViT { impl Module for TinyViT { fn forward(&self, xs: &Tensor) -> Result { - let mut xs = self.patch_embed.forward(xs)?; + let xs = self.patch_embed.forward(xs)?; + let mut xs = self.layer0.forward(&xs)?; for layer in self.layers.iter() { xs = layer.forward(&xs)? } @@ -551,7 +554,7 @@ impl Module for TinyViT { } } -pub fn tiny_vit_5m_224(vb: VarBuilder) -> Result { +pub fn tiny_vit_5m(vb: VarBuilder) -> Result { TinyViT::new( /* embed_dims */ &[64, 128, 160, 320], /* depths */ &[2, 2, 6, 2],