mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Parse the json config for siglip models. (#2800)
* Parse the json config for siglip models. * Bump the tokenizers dependency. * Add a v2 model. * Support more v2 model.s
This commit is contained in:
@ -13,11 +13,40 @@ use candle_transformers::models::siglip;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
|
||||
enum Which {
|
||||
#[value(name = "v1-base-patch16-224")]
|
||||
V1BasePatch16_224,
|
||||
#[value(name = "v2-base-patch16-224")]
|
||||
V2BasePatch16_224,
|
||||
#[value(name = "v2-base-patch16-256")]
|
||||
V2BasePatch16_256,
|
||||
#[value(name = "v2-base-patch16-384")]
|
||||
V2BasePatch16_384,
|
||||
#[value(name = "v2-base-patch16-512")]
|
||||
V2BasePatch16_512,
|
||||
#[value(name = "v2-large-patch16-256")]
|
||||
V2LargePatch16_256,
|
||||
#[value(name = "v2-large-patch16-384")]
|
||||
V2LargePatch16_384,
|
||||
#[value(name = "v2-large-patch16-512")]
|
||||
V2LargePatch16_512,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
hf_repo: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "v1-base-patch16-224")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
@ -66,16 +95,37 @@ fn load_images<T: AsRef<std::path::Path>>(
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let hf_repo = match args.hf_repo.as_ref() {
|
||||
Some(hf_repo) => hf_repo,
|
||||
None => match args.which {
|
||||
Which::V1BasePatch16_224 => "google/siglip-base-patch16-224",
|
||||
Which::V2BasePatch16_224 => "google/siglip2-base-patch16-224",
|
||||
Which::V2BasePatch16_256 => "google/siglip2-base-patch16-256",
|
||||
Which::V2BasePatch16_384 => "google/siglip2-base-patch16-384",
|
||||
Which::V2BasePatch16_512 => "google/siglip2-base-patch16-512",
|
||||
Which::V2LargePatch16_256 => "google/siglip2-large-patch16-256",
|
||||
Which::V2LargePatch16_384 => "google/siglip2-large-patch16-384",
|
||||
Which::V2LargePatch16_512 => "google/siglip2-large-patch16-512",
|
||||
},
|
||||
};
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("google/siglip-base-patch16-224".to_string());
|
||||
let api = api.model(hf_repo.to_string());
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let tokenizer = get_tokenizer(args.tokenizer)?;
|
||||
let config = siglip::Config::base_patch16_224();
|
||||
let config_file = match args.config {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(hf_repo.to_string());
|
||||
api.get("config.json")?
|
||||
}
|
||||
Some(config) => config.into(),
|
||||
};
|
||||
let tokenizer = get_tokenizer(hf_repo, args.tokenizer)?;
|
||||
let config: siglip::Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vec_imgs = match args.images {
|
||||
Some(imgs) => imgs,
|
||||
@ -114,11 +164,11 @@ pub fn main() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
|
||||
pub fn get_tokenizer(hf_repo: &str, tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
|
||||
let tokenizer = match tokenizer {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("google/siglip-base-patch16-224".to_string());
|
||||
let api = api.model(hf_repo.to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
Some(file) => file.into(),
|
||||
|
Reference in New Issue
Block a user