Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00.
Get the MobileSAM TinyViT based version to work. (#789)

* More TinyViT support in SA.
* More mobilesam work.
* Add the mobile-sam weights to the hub.
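Note: with the new flag in place, the MobileSAM backbone can be selected from the command line. Assuming the example binary keeps its existing segment-anything name, an invocation would look like:

    cargo run --example segment-anything --release -- --use-tiny

Without the flag, the example keeps downloading and running the original sam_vit_b checkpoint.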
@@ -133,6 +133,10 @@ struct Args {
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
+
+    /// Use the TinyViT based models from MobileSAM
+    #[arg(long)]
+    use_tiny: bool,
 }
 
 pub fn main() -> anyhow::Result<()> {
@@ -179,13 +183,22 @@ pub fn main() -> anyhow::Result<()> {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model("lmz/candle-sam".to_string());
-            api.get("sam_vit_b_01ec64.safetensors")?
+            let filename = if args.use_tiny {
+                "mobile_sam-tiny-vitt.safetensors"
+            } else {
+                "sam_vit_b_01ec64.safetensors"
+            };
+            api.get(filename)?
         }
     };
     let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
-    let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+    let sam = if args.use_tiny {
+        model_sam::Sam::new_tiny(vb)? // tiny vit_t
+    } else {
+        model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+    };
 
     if args.generate_masks {
         // Default options similar to the Python version.
@@ -4,6 +4,7 @@ use candle_nn::{Module, VarBuilder};
 use crate::model_image_encoder::ImageEncoderViT;
 use crate::model_mask_decoder::MaskDecoder;
 use crate::model_prompt_encoder::PromptEncoder;
+use crate::model_tiny_vit::{tiny_vit_5m, TinyViT};
 
 const PROMPT_EMBED_DIM: usize = 256;
 pub const IMAGE_SIZE: usize = 1024;
@@ -14,9 +15,24 @@ const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
 const MODEL_MASK_THRESHOLD: f32 = 0.0;
 const CROP_NMS_THRESH: f32 = 0.7;
 
+#[derive(Debug)]
+enum ImageEncoder {
+    Original(ImageEncoderViT),
+    TinyViT(TinyViT),
+}
+
+impl Module for ImageEncoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::Original(vit) => vit.forward(xs),
+            Self::TinyViT(vit) => vit.forward(xs),
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct Sam {
-    image_encoder: ImageEncoderViT,
+    image_encoder: ImageEncoder,
     prompt_encoder: PromptEncoder,
     mask_decoder: MaskDecoder,
     pixel_mean: Tensor,
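Note: the enum-plus-match wrapper is a common Rust alternative to a Box<dyn Module> trait object: the set of encoders is closed, each forward call dispatches statically, and #[derive(Debug)] keeps working. A minimal self-contained sketch of the same pattern, with hypothetical stand-in types:

    #[derive(Debug)]
    struct ViTB; // stand-in for ImageEncoderViT
    #[derive(Debug)]
    struct Tiny; // stand-in for TinyViT

    impl ViTB {
        fn forward(&self, xs: &[f32]) -> Vec<f32> {
            xs.iter().map(|x| x + 1.0).collect()
        }
    }
    impl Tiny {
        fn forward(&self, xs: &[f32]) -> Vec<f32> {
            xs.iter().map(|x| x * 2.0).collect()
        }
    }

    #[derive(Debug)]
    enum ImageEncoder {
        Original(ViTB),
        TinyViT(Tiny),
    }

    impl ImageEncoder {
        // Each arm forwards to the concrete type, as in the Module impl above.
        fn forward(&self, xs: &[f32]) -> Vec<f32> {
            match self {
                Self::Original(v) => v.forward(xs),
                Self::TinyViT(v) => v.forward(xs),
            }
        }
    }

    fn main() {
        let enc = ImageEncoder::TinyViT(Tiny);
        assert_eq!(enc.forward(&[1.0, 2.0]), vec![2.0, 4.0]);
    }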
@@ -67,7 +83,38 @@ impl Sam {
         let pixel_std =
             Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
         Ok(Self {
-            image_encoder,
+            image_encoder: ImageEncoder::Original(image_encoder),
+            prompt_encoder,
+            mask_decoder,
+            pixel_std,
+            pixel_mean,
+        })
+    }
+
+    pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
+        let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+        let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
+        let prompt_encoder = PromptEncoder::new(
+            PROMPT_EMBED_DIM,
+            (image_embedding_size, image_embedding_size),
+            (IMAGE_SIZE, IMAGE_SIZE),
+            16,
+            vb.pp("prompt_encoder"),
+        )?;
+        let mask_decoder = MaskDecoder::new(
+            PROMPT_EMBED_DIM,
+            /* num_multitask_outputs */ 3,
+            /* iou_head_depth */ 3,
+            /* iou_head_hidden_dim */ 256,
+            vb.pp("mask_decoder"),
+        )?;
+        let pixel_mean =
+            Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+        let pixel_std =
+            Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+        Ok(Self {
+            image_encoder: ImageEncoder::TinyViT(image_encoder),
             prompt_encoder,
             mask_decoder,
             pixel_std,
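Note: new_tiny reuses the same prompt-encoder geometry as the ViT-B constructor. Assuming VIT_PATCH_SIZE is 16 as in the original SAM, the embedding grid works out to 64x64 for both backbones:

    const IMAGE_SIZE: usize = 1024;
    const VIT_PATCH_SIZE: usize = 16; // assumed, as in the original SAM

    fn main() {
        let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
        assert_eq!(image_embedding_size, 64); // 1024 / 16 = 64
    }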
@@ -1,13 +1,12 @@
 // Adapted from:
 // https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
-#![allow(unused)]
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{IndexOp, Result, Tensor, D};
 use candle_nn::{Conv2dConfig, Module, VarBuilder};
 
 const MBCONV_EXPAND_RATIO: usize = 4;
 const MLP_RATIO: usize = 4;
 const LOCAL_CONV_SIZE: usize = 3;
-const IMG_SIZE: usize = 224;
+const IMG_SIZE: usize = 1024;
 const IN_CHANNELS: usize = 3;
 
 #[derive(Debug)]
@@ -18,7 +17,7 @@ struct Conv2dBN {
 
 impl Conv2dBN {
     fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
-        let c = candle_nn::conv2d(in_, out, ks, cfg, vb.pp("c"))?;
+        let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
         let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
         Ok(Self { c, bn })
     }
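Note: dropping the convolution bias in front of a batch norm is lossless, which is presumably why the checkpoint stores no c.bias tensor here. For a batch norm with statistics mu, sigma and affine parameters gamma, beta:

    BN(W*x + b) = gamma * ((W*x + b) - (mu + b)) / sigma + beta
                = gamma * (W*x - mu') / sigma + beta,   with mu' = mean of W*x

so a constant bias only shifts the mean that the batch norm immediately subtracts back out.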
@@ -222,7 +221,6 @@ struct Attention {
     norm: candle_nn::LayerNorm,
     qkv: candle_nn::Linear,
     proj: candle_nn::Linear,
-    attention_biases: Tensor,
     ab: Tensor,
     key_dim: usize,
     num_heads: usize,
@@ -263,12 +261,14 @@ impl Attention {
         }
         let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
         let idxs = Tensor::new(idxs, attention_biases.device())?;
-        let ab = attention_biases.index_select(&idxs, 1)?;
+        let ab =
+            attention_biases
+                .index_select(&idxs, 1)?
+                .reshape(((), points.len(), points.len()))?;
         Ok(Self {
             norm,
             qkv,
             proj,
-            attention_biases,
             ab,
             key_dim,
             num_heads,
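Note: the raw attention_biases parameter is dropped from the struct because only the gathered, reshaped ab is used in forward. A minimal candle sketch of the gather-then-reshape, with made-up shapes (2 heads, 4 points, 3 distinct offsets):

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (num_heads, n, num_offsets) = (2usize, 4usize, 3usize);
        // One learned bias per (head, relative offset).
        let biases = Tensor::rand(0f32, 1f32, (num_heads, num_offsets), &dev)?;
        // One offset index per (query, key) pair, n * n entries in total.
        let idxs: Vec<u32> = (0..(n * n) as u32).map(|i| i % num_offsets as u32).collect();
        let idxs = Tensor::new(idxs.as_slice(), &dev)?;
        // Gather along the offset dim, then expose per-head (n, n) bias maps
        // so they can be broadcast over the batch dimension later.
        let ab = biases.index_select(&idxs, 1)?.reshape(((), n, n))?;
        assert_eq!(ab.dims(), &[num_heads, n, n]);
        Ok(())
    }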
@@ -286,15 +286,18 @@ impl Module for Attention {
         let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
         let q = qkv
             .narrow(D::Minus1, 0, self.key_dim)?
-            .permute((0, 2, 1, 3))?;
+            .permute((0, 2, 1, 3))?
+            .contiguous()?;
         let k = qkv
             .narrow(D::Minus1, self.key_dim, self.key_dim)?
-            .permute((0, 2, 1, 3))?;
+            .permute((0, 2, 1, 3))?
+            .contiguous()?;
         let v = qkv
             .narrow(D::Minus1, 2 * self.key_dim, self.d)?
-            .permute((0, 2, 1, 3))?;
+            .permute((0, 2, 1, 3))?
+            .contiguous()?;
         let attn = (q.matmul(&k.t()?)? * self.scale)?;
-        let attn = (attn + &self.ab)?;
+        let attn = attn.broadcast_add(&self.ab)?;
         let attn = candle_nn::ops::softmax_last_dim(&attn)?;
         attn.matmul(&v)?
             .transpose(1, 2)?
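Note: two separate fixes land here. candle's plain + requires both operands to have identical shapes, so adding the new (num_heads, n, n) bias to (batch, num_heads, n, n) scores needs broadcast_add; and the .contiguous()? calls materialize the permuted views, presumably because the matmul kernels expect contiguous inputs. A small shape-only sketch of the broadcast:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, heads, n) = (2usize, 2usize, 4usize);
        let attn = Tensor::rand(0f32, 1f32, (b, heads, n, n), &dev)?;
        let ab = Tensor::rand(0f32, 1f32, (heads, n, n), &dev)?;
        // Trailing dims align, so ab is implicitly expanded over the batch dim.
        let attn = attn.broadcast_add(&ab)?;
        assert_eq!(attn.dims(), &[b, heads, n, n]);
        Ok(())
    }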
@@ -332,6 +335,7 @@ impl TinyViTBlock {
         let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
         let cfg = candle_nn::Conv2dConfig {
             padding: LOCAL_CONV_SIZE / 2,
+            groups: dim,
             ..Default::default()
         };
         let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
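Note: with groups: dim the local convolution becomes depthwise, so its weight has shape (dim, 1, 3, 3) rather than (dim, dim, 3, 3); without it, loading the MobileSAM local_conv weights would presumably fail on a shape mismatch. A hedged sketch with illustrative sizes:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{Conv2dConfig, Module, VarBuilder};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let dim = 8;
        let cfg = Conv2dConfig {
            padding: 1, // LOCAL_CONV_SIZE / 2 for a 3x3 kernel
            groups: dim,
            ..Default::default()
        };
        // Zero-initialized builder so the snippet runs without a checkpoint.
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let conv = candle_nn::conv2d_no_bias(dim, dim, 3, cfg, vb.pp("local_conv"))?;
        let xs = Tensor::zeros((1, dim, 16, 16), DType::F32, &dev)?;
        assert_eq!(conv.forward(&xs)?.dims(), &[1, dim, 16, 16]);
        Ok(())
    }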
@@ -358,12 +362,12 @@ impl Module for TinyViTBlock {
         let pad_r = (self.window_size - w % self.window_size) % self.window_size;
 
         let xs = if pad_b > 0 {
-            xs.pad_with_zeros(D::Minus2, 0, pad_b)?
+            xs.pad_with_zeros(1, 0, pad_b)?
         } else {
             xs
         };
         let xs = if pad_r > 0 {
-            xs.pad_with_zeros(D::Minus1, 0, pad_r)?
+            xs.pad_with_zeros(2, 0, pad_r)?
         } else {
             xs
         };
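Note: judging from the replacement indices, the activations at this point are laid out (batch, height, width, channels), so the bottom/right padding belongs on dims 1 and 2; D::Minus2/D::Minus1 were hitting width and channels instead. A shape-only sketch, assuming that layout:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // (batch, height, width, channels), with h and w not multiples of
        // the window size.
        let xs = Tensor::zeros((1, 5, 6, 32), DType::F32, &dev)?;
        let (pad_b, pad_r) = (3, 2); // pad up to an 8x8 grid
        let xs = xs.pad_with_zeros(1, 0, pad_b)?; // height: 5 -> 8
        let xs = xs.pad_with_zeros(2, 0, pad_r)?; // width: 6 -> 8
        assert_eq!(xs.dims(), &[1, 8, 8, 32]);
        Ok(())
    }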
@@ -460,8 +464,8 @@ pub struct TinyViT {
     patch_embed: PatchEmbed,
     layer0: ConvLayer,
     layers: Vec<BasicLayer>,
-    norm_head: candle_nn::LayerNorm,
-    head: candle_nn::Linear,
+    // norm_head: candle_nn::LayerNorm,
+    // head: candle_nn::Linear,
     neck_conv1: candle_nn::Conv2d,
     neck_ln1: crate::LayerNorm2d,
     neck_conv2: candle_nn::Conv2d,
@@ -474,7 +478,7 @@ impl TinyViT {
         depths: &[usize],
         num_heads: &[usize],
         window_sizes: &[usize],
-        num_classes: usize,
+        _num_classes: usize,
         vb: VarBuilder,
     ) -> Result<Self> {
         let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
@@ -509,8 +513,8 @@ impl TinyViT {
         }
 
         let last_embed_dim = embed_dims[embed_dims.len() - 1];
-        let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
-        let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
+        // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
+        // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
         let neck_conv1 =
             candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
         let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
@@ -525,8 +529,6 @@ impl TinyViT {
             patch_embed,
             layer0,
             layers,
-            norm_head,
-            head,
             neck_conv1,
             neck_ln1,
             neck_conv2,
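Note: the classification head only matters for ImageNet-style use of TinyViT; SAM consumes the neck output. Commenting the head out, rather than leaving it unused, also matters for loading, since a safetensors-backed VarBuilder fetches tensors eagerly and the MobileSAM checkpoint presumably ships no head weights. A sketch of that eager fetch:

    use candle::{DType, Device, Result};
    use candle_nn::VarBuilder;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // zeros() resolves every name, so this snippet runs standalone; with
        // VarBuilder::from_safetensors, a missing "head.weight" would
        // presumably surface as an error on this call instead.
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let _head = candle_nn::linear(320, 1000, vb.pp("head"))?;
        Ok(())
    }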
@@ -537,7 +539,8 @@ impl TinyViT {
 
 impl Module for TinyViT {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = self.patch_embed.forward(xs)?;
+        let xs = self.patch_embed.forward(xs)?;
+        let mut xs = self.layer0.forward(&xs)?;
         for layer in self.layers.iter() {
             xs = layer.forward(&xs)?
         }
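Note: this fixes a skipped stage. layer0 was constructed but never applied; the old forward went straight from patch_embed into the self.layers loop. A tiny stand-in sketch of the corrected order:

    fn main() {
        // Hypothetical stand-ins for the real modules, to show order only.
        let patch_embed = |x: i32| x + 1;
        let layer0 = |x: i32| x * 10;
        let layers: Vec<Box<dyn Fn(i32) -> i32>> = vec![Box::new(|x| x - 3)];

        let xs = patch_embed(0);
        let mut xs = layer0(xs); // the step that was previously skipped
        for layer in layers.iter() {
            xs = layer(xs);
        }
        assert_eq!(xs, 7); // (0 + 1) * 10 - 3
    }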
@@ -551,7 +554,7 @@ impl Module for TinyViT {
     }
 }
 
-pub fn tiny_vit_5m_224(vb: VarBuilder) -> Result<TinyViT> {
+pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
     TinyViT::new(
         /* embed_dims */ &[64, 128, 160, 320],
         /* depths */ &[2, 2, 6, 2],