mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Get the MobileSAM TinyViT based version to work. (#789)
* More TinyViT support in SA. * More mobilesam work. * Add the mobile-sam weights to the hub.
This commit is contained in:
@ -133,6 +133,10 @@ struct Args {
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Use the TinyViT based models from MobileSAM
|
||||
#[arg(long)]
|
||||
use_tiny: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
@ -179,13 +183,22 @@ pub fn main() -> anyhow::Result<()> {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-sam".to_string());
|
||||
api.get("sam_vit_b_01ec64.safetensors")?
|
||||
let filename = if args.use_tiny {
|
||||
"mobile_sam-tiny-vitt.safetensors"
|
||||
} else {
|
||||
"sam_vit_b_01ec64.safetensors"
|
||||
};
|
||||
api.get(filename)?
|
||||
}
|
||||
};
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
|
||||
let sam = if args.use_tiny {
|
||||
model_sam::Sam::new_tiny(vb)? // tiny vit_t
|
||||
} else {
|
||||
model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
|
||||
};
|
||||
|
||||
if args.generate_masks {
|
||||
// Default options similar to the Python version.
|
||||
|
Reference in New Issue
Block a user