mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Parse the json config for siglip models. (#2800)
* Parse the json config for siglip models. * Bump the tokenizers dependency. * Add a v2 model. * Support more v2 models.
This commit is contained in:
@ -66,7 +66,7 @@ serde = { version = "1.0.171", features = ["derive"] }
|
|||||||
serde_plain = "1.0.2"
|
serde_plain = "1.0.2"
|
||||||
serde_json = "1.0.99"
|
serde_json = "1.0.99"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tokenizers = { version = "0.19.1", default-features = false }
|
tokenizers = { version = "0.21.0", default-features = false }
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-chrome = "0.7.1"
|
tracing-chrome = "0.7.1"
|
||||||
tracing-subscriber = "0.3.7"
|
tracing-subscriber = "0.3.7"
|
||||||
|
@ -13,11 +13,40 @@ use candle_transformers::models::siglip;
|
|||||||
|
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
/// Named siglip checkpoints that can be picked from the command line.
///
/// Each variant is exposed under the string given in its `#[value(name = ...)]`
/// attribute; `main` maps the chosen variant to a `google/siglip*` repository
/// name when no explicit `hf_repo` argument is supplied.
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "v1-base-patch16-224")]
|
||||||
|
V1BasePatch16_224,
|
||||||
|
#[value(name = "v2-base-patch16-224")]
|
||||||
|
V2BasePatch16_224,
|
||||||
|
#[value(name = "v2-base-patch16-256")]
|
||||||
|
V2BasePatch16_256,
|
||||||
|
#[value(name = "v2-base-patch16-384")]
|
||||||
|
V2BasePatch16_384,
|
||||||
|
#[value(name = "v2-base-patch16-512")]
|
||||||
|
V2BasePatch16_512,
|
||||||
|
#[value(name = "v2-large-patch16-256")]
|
||||||
|
V2LargePatch16_256,
|
||||||
|
#[value(name = "v2-large-patch16-384")]
|
||||||
|
V2LargePatch16_384,
|
||||||
|
#[value(name = "v2-large-patch16-512")]
|
||||||
|
V2LargePatch16_512,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
struct Args {
|
struct Args {
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
hf_repo: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "v1-base-patch16-224")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer: Option<String>,
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
@ -66,16 +95,37 @@ fn load_images<T: AsRef<std::path::Path>>(
|
|||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
pub fn main() -> anyhow::Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
let hf_repo = match args.hf_repo.as_ref() {
|
||||||
|
Some(hf_repo) => hf_repo,
|
||||||
|
None => match args.which {
|
||||||
|
Which::V1BasePatch16_224 => "google/siglip-base-patch16-224",
|
||||||
|
Which::V2BasePatch16_224 => "google/siglip2-base-patch16-224",
|
||||||
|
Which::V2BasePatch16_256 => "google/siglip2-base-patch16-256",
|
||||||
|
Which::V2BasePatch16_384 => "google/siglip2-base-patch16-384",
|
||||||
|
Which::V2BasePatch16_512 => "google/siglip2-base-patch16-512",
|
||||||
|
Which::V2LargePatch16_256 => "google/siglip2-large-patch16-256",
|
||||||
|
Which::V2LargePatch16_384 => "google/siglip2-large-patch16-384",
|
||||||
|
Which::V2LargePatch16_512 => "google/siglip2-large-patch16-512",
|
||||||
|
},
|
||||||
|
};
|
||||||
let model_file = match args.model {
|
let model_file = match args.model {
|
||||||
None => {
|
None => {
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model("google/siglip-base-patch16-224".to_string());
|
let api = api.model(hf_repo.to_string());
|
||||||
api.get("model.safetensors")?
|
api.get("model.safetensors")?
|
||||||
}
|
}
|
||||||
Some(model) => model.into(),
|
Some(model) => model.into(),
|
||||||
};
|
};
|
||||||
let tokenizer = get_tokenizer(args.tokenizer)?;
|
let config_file = match args.config {
|
||||||
let config = siglip::Config::base_patch16_224();
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model(hf_repo.to_string());
|
||||||
|
api.get("config.json")?
|
||||||
|
}
|
||||||
|
Some(config) => config.into(),
|
||||||
|
};
|
||||||
|
let tokenizer = get_tokenizer(hf_repo, args.tokenizer)?;
|
||||||
|
let config: siglip::Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let vec_imgs = match args.images {
|
let vec_imgs = match args.images {
|
||||||
Some(imgs) => imgs,
|
Some(imgs) => imgs,
|
||||||
@ -114,11 +164,11 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
|
pub fn get_tokenizer(hf_repo: &str, tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
|
||||||
let tokenizer = match tokenizer {
|
let tokenizer = match tokenizer {
|
||||||
None => {
|
None => {
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model("google/siglip-base-patch16-224".to_string());
|
let api = api.model(hf_repo.to_string());
|
||||||
api.get("tokenizer.json")?
|
api.get("tokenizer.json")?
|
||||||
}
|
}
|
||||||
Some(file) => file.into(),
|
Some(file) => file.into(),
|
||||||
|
@ -10,33 +10,133 @@ use crate::models::clip::div_l2_norm;
|
|||||||
use candle::{IndexOp, Module, Result, Tensor, D};
|
use candle::{IndexOp, Module, Result, Tensor, D};
|
||||||
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
|
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::vocab_size` when the key is absent: 32000.
fn default_text_vocab_size() -> usize {
|
||||||
|
32000
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::hidden_size` when the key is absent: 768.
fn default_text_hidden_size() -> usize {
|
||||||
|
768
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::intermediate_size` when the key is absent: 3072.
fn default_text_intermediate_size() -> usize {
|
||||||
|
3072
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::num_hidden_layers` when the key is absent: 12.
fn default_text_num_hidden_layers() -> usize {
|
||||||
|
12
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::num_attention_heads` when the key is absent: 12.
fn default_text_num_attention_heads() -> usize {
|
||||||
|
12
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::max_position_embeddings` when the key is absent: 64.
fn default_text_max_position_embeddings() -> usize {
|
||||||
|
64
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::layer_norm_eps` when the key is absent: 1e-6.
fn default_text_layer_norm_eps() -> f64 {
|
||||||
|
1e-6
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::pad_token_id` when the key is absent: 1.
fn default_text_pad_token_id() -> u32 {
|
||||||
|
1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::bos_token_id` when the key is absent: 49406.
fn default_text_bos_token_id() -> u32 {
|
||||||
|
49406
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::eos_token_id` when the key is absent: 49407.
fn default_text_eos_token_id() -> u32 {
|
||||||
|
49407
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `TextConfig::hidden_act` when the key is absent:
/// `candle_nn::Activation::GeluPytorchTanh`.
fn default_text_hidden_act() -> candle_nn::Activation {
|
||||||
|
candle_nn::Activation::GeluPytorchTanh
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
|
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
|
||||||
/// Text-side settings of a siglip configuration, deserializable with serde.
///
/// In the new revision every field carries a `#[serde(default = "...")]`
/// attribute, so a key missing from the JSON falls back to the matching
/// `default_text_*` function instead of failing deserialization.
#[derive(serde::Deserialize, Clone, Debug)]
|
#[derive(serde::Deserialize, Clone, Debug)]
|
||||||
pub struct TextConfig {
|
pub struct TextConfig {
|
||||||
|
#[serde(default = "default_text_vocab_size")]
|
||||||
pub vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
|
#[serde(default = "default_text_hidden_size")]
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
|
#[serde(default = "default_text_intermediate_size")]
|
||||||
pub intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
|
#[serde(default = "default_text_num_hidden_layers")]
|
||||||
pub num_hidden_layers: usize,
|
pub num_hidden_layers: usize,
|
||||||
|
#[serde(default = "default_text_num_attention_heads")]
|
||||||
pub num_attention_heads: usize,
|
pub num_attention_heads: usize,
|
||||||
|
#[serde(default = "default_text_max_position_embeddings")]
|
||||||
pub max_position_embeddings: usize,
|
pub max_position_embeddings: usize,
|
||||||
|
#[serde(default = "default_text_hidden_act")]
|
||||||
pub hidden_act: candle_nn::Activation,
|
pub hidden_act: candle_nn::Activation,
|
||||||
|
#[serde(default = "default_text_layer_norm_eps")]
|
||||||
pub layer_norm_eps: f64,
|
pub layer_norm_eps: f64,
|
||||||
|
#[serde(default = "default_text_pad_token_id")]
|
||||||
pub pad_token_id: u32,
|
pub pad_token_id: u32,
|
||||||
|
#[serde(default = "default_text_bos_token_id")]
|
||||||
pub bos_token_id: u32,
|
pub bos_token_id: u32,
|
||||||
|
#[serde(default = "default_text_eos_token_id")]
|
||||||
pub eos_token_id: u32,
|
pub eos_token_id: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `VisionConfig::hidden_size` when the key is absent: 768.
fn default_vision_hidden_size() -> usize {
|
||||||
|
768
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `VisionConfig::intermediate_size` when the key is absent: 3072.
fn default_vision_intermediate_size() -> usize {
|
||||||
|
3072
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `VisionConfig::num_hidden_layers` when the key is absent: 12.
fn default_vision_num_hidden_layers() -> usize {
|
||||||
|
12
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `VisionConfig::num_attention_heads` when the key is absent: 12.
fn default_vision_num_attention_heads() -> usize {
|
||||||
|
12
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `VisionConfig::num_channels` when the key is absent: 3.
fn default_vision_num_channels() -> usize {
|
||||||
|
3
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `VisionConfig::image_size` when the key is absent: 224.
fn default_vision_image_size() -> usize {
|
||||||
|
224
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback returning 16.
///
/// NOTE(review): despite the "batch_size" in its name, this function is wired
/// up as the `#[serde(default = ...)]` of `VisionConfig::patch_size`, not of
/// any batch-size field. Consider renaming it to `default_vision_patch_size`
/// (together with the attribute string on the field) to avoid confusion.
fn default_vision_batch_size() -> usize {
|
||||||
|
16
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `VisionConfig::layer_norm_eps` when the key is absent: 1e-6.
fn default_vision_layer_norm_eps() -> f64 {
|
||||||
|
1e-6
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde fallback for `VisionConfig::hidden_act` when the key is absent:
/// `candle_nn::Activation::GeluPytorchTanh`.
fn default_vision_hidden_act() -> candle_nn::Activation {
|
||||||
|
candle_nn::Activation::GeluPytorchTanh
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
|
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
|
||||||
/// Vision-side settings of a siglip configuration, deserializable with serde.
///
/// In the new revision every field carries a `#[serde(default = "...")]`
/// attribute, so a key missing from the JSON falls back to the named
/// `default_vision_*` function instead of failing deserialization.
#[derive(serde::Deserialize, Clone, Debug)]
|
#[derive(serde::Deserialize, Clone, Debug)]
|
||||||
pub struct VisionConfig {
|
pub struct VisionConfig {
|
||||||
|
#[serde(default = "default_vision_hidden_size")]
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
|
#[serde(default = "default_vision_intermediate_size")]
|
||||||
pub intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
|
#[serde(default = "default_vision_num_hidden_layers")]
|
||||||
pub num_hidden_layers: usize,
|
pub num_hidden_layers: usize,
|
||||||
|
#[serde(default = "default_vision_num_attention_heads")]
|
||||||
pub num_attention_heads: usize,
|
pub num_attention_heads: usize,
|
||||||
|
#[serde(default = "default_vision_num_channels")]
|
||||||
pub num_channels: usize,
|
pub num_channels: usize,
|
||||||
|
#[serde(default = "default_vision_image_size")]
|
||||||
pub image_size: usize,
|
pub image_size: usize,
|
||||||
|
// NOTE(review): the default function is named "batch_size" but supplies the
|
// default for `patch_size`; the name looks like a typo for
|
// `default_vision_patch_size`.
#[serde(default = "default_vision_batch_size")]
|
||||||
pub patch_size: usize,
|
pub patch_size: usize,
|
||||||
|
#[serde(default = "default_vision_hidden_act")]
|
||||||
pub hidden_act: candle_nn::Activation,
|
pub hidden_act: candle_nn::Activation,
|
||||||
|
#[serde(default = "default_vision_layer_norm_eps")]
|
||||||
pub layer_norm_eps: f64,
|
pub layer_norm_eps: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user