Get the MobileSAM TinyViT based version to work. (#789)

* More TinyViT support in SA.

* More mobilesam work.

* Add the mobile-sam weights to the hub.
This commit is contained in:
Laurent Mazare
2023-09-09 16:21:44 +01:00
committed by GitHub
parent b7cd58473b
commit 74ad4deb42
3 changed files with 89 additions and 26 deletions

View File

@ -133,6 +133,10 @@ struct Args {
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Use the TinyViT based models from MobileSAM
#[arg(long)]
use_tiny: bool,
}
pub fn main() -> anyhow::Result<()> {
@ -179,13 +183,22 @@ pub fn main() -> anyhow::Result<()> {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-sam".to_string());
api.get("sam_vit_b_01ec64.safetensors")?
let filename = if args.use_tiny {
"mobile_sam-tiny-vitt.safetensors"
} else {
"sam_vit_b_01ec64.safetensors"
};
api.get(filename)?
}
};
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
let sam = if args.use_tiny {
model_sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks {
// Default options similar to the Python version.