mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Get the MobileSAM TinyViT based version to work. (#789)
* More TinyViT support in SA. * More mobilesam work. * Add the mobile-sam weights to the hub.
This commit is contained in:
@ -4,6 +4,7 @@ use candle_nn::{Module, VarBuilder};
|
||||
use crate::model_image_encoder::ImageEncoderViT;
|
||||
use crate::model_mask_decoder::MaskDecoder;
|
||||
use crate::model_prompt_encoder::PromptEncoder;
|
||||
use crate::model_tiny_vit::{tiny_vit_5m, TinyViT};
|
||||
|
||||
const PROMPT_EMBED_DIM: usize = 256;
|
||||
pub const IMAGE_SIZE: usize = 1024;
|
||||
@ -14,9 +15,24 @@ const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
|
||||
const MODEL_MASK_THRESHOLD: f32 = 0.0;
|
||||
const CROP_NMS_THRESH: f32 = 0.7;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ImageEncoder {
|
||||
Original(ImageEncoderViT),
|
||||
TinyViT(TinyViT),
|
||||
}
|
||||
|
||||
impl Module for ImageEncoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::Original(vit) => vit.forward(xs),
|
||||
Self::TinyViT(vit) => vit.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Sam {
|
||||
image_encoder: ImageEncoderViT,
|
||||
image_encoder: ImageEncoder,
|
||||
prompt_encoder: PromptEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
pixel_mean: Tensor,
|
||||
@ -67,7 +83,38 @@ impl Sam {
|
||||
let pixel_std =
|
||||
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
||||
Ok(Self {
|
||||
image_encoder,
|
||||
image_encoder: ImageEncoder::Original(image_encoder),
|
||||
prompt_encoder,
|
||||
mask_decoder,
|
||||
pixel_std,
|
||||
pixel_mean,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
|
||||
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
|
||||
|
||||
let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
|
||||
let prompt_encoder = PromptEncoder::new(
|
||||
PROMPT_EMBED_DIM,
|
||||
(image_embedding_size, image_embedding_size),
|
||||
(IMAGE_SIZE, IMAGE_SIZE),
|
||||
16,
|
||||
vb.pp("prompt_encoder"),
|
||||
)?;
|
||||
let mask_decoder = MaskDecoder::new(
|
||||
PROMPT_EMBED_DIM,
|
||||
/* num_multitask_outputs */ 3,
|
||||
/* iou_head_depth */ 3,
|
||||
/* iou_head_hidden_dim */ 256,
|
||||
vb.pp("mask_decoder"),
|
||||
)?;
|
||||
let pixel_mean =
|
||||
Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
|
||||
let pixel_std =
|
||||
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
||||
Ok(Self {
|
||||
image_encoder: ImageEncoder::TinyViT(image_encoder),
|
||||
prompt_encoder,
|
||||
mask_decoder,
|
||||
pixel_std,
|
||||
|
Reference in New Issue
Block a user