diff --git a/Cargo.toml b/Cargo.toml
index f86508d9..67094ac6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -66,7 +66,7 @@ serde = { version = "1.0.171", features = ["derive"] }
 serde_plain = "1.0.2"
 serde_json = "1.0.99"
 thiserror = "1"
-tokenizers = { version = "0.19.1", default-features = false }
+tokenizers = { version = "0.21.0", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
diff --git a/candle-examples/examples/siglip/main.rs b/candle-examples/examples/siglip/main.rs
index bdd8f096..a78ed7f5 100644
--- a/candle-examples/examples/siglip/main.rs
+++ b/candle-examples/examples/siglip/main.rs
@@ -13,11 +13,40 @@
 use candle_transformers::models::siglip;
 use tokenizers::Tokenizer;
 
+#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
+enum Which {
+    #[value(name = "v1-base-patch16-224")]
+    V1BasePatch16_224,
+    #[value(name = "v2-base-patch16-224")]
+    V2BasePatch16_224,
+    #[value(name = "v2-base-patch16-256")]
+    V2BasePatch16_256,
+    #[value(name = "v2-base-patch16-384")]
+    V2BasePatch16_384,
+    #[value(name = "v2-base-patch16-512")]
+    V2BasePatch16_512,
+    #[value(name = "v2-large-patch16-256")]
+    V2LargePatch16_256,
+    #[value(name = "v2-large-patch16-384")]
+    V2LargePatch16_384,
+    #[value(name = "v2-large-patch16-512")]
+    V2LargePatch16_512,
+}
+
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
     model: Option<String>,
 
+    #[arg(long)]
+    config: Option<String>,
+
+    #[arg(long)]
+    hf_repo: Option<String>,
+
+    #[arg(long, default_value = "v1-base-patch16-224")]
+    which: Which,
+
     #[arg(long)]
     tokenizer: Option<String>,
 
@@ -66,16 +66,37 @@ fn load_images<T: AsRef<std::path::Path>>(
 
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
+    let hf_repo = match args.hf_repo.as_ref() {
+        Some(hf_repo) => hf_repo,
+        None => match args.which {
+            Which::V1BasePatch16_224 => "google/siglip-base-patch16-224",
+            Which::V2BasePatch16_224 => "google/siglip2-base-patch16-224",
+            Which::V2BasePatch16_256 => "google/siglip2-base-patch16-256",
+            Which::V2BasePatch16_384 => "google/siglip2-base-patch16-384",
+            Which::V2BasePatch16_512 => "google/siglip2-base-patch16-512",
+            Which::V2LargePatch16_256 => "google/siglip2-large-patch16-256",
+            Which::V2LargePatch16_384 => "google/siglip2-large-patch16-384",
+            Which::V2LargePatch16_512 => "google/siglip2-large-patch16-512",
+        },
+    };
     let model_file = match args.model {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
-            let api = api.model("google/siglip-base-patch16-224".to_string());
+            let api = api.model(hf_repo.to_string());
             api.get("model.safetensors")?
         }
         Some(model) => model.into(),
     };
-    let tokenizer = get_tokenizer(args.tokenizer)?;
-    let config = siglip::Config::base_patch16_224();
+    let config_file = match args.config {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model(hf_repo.to_string());
+            api.get("config.json")?
+        }
+        Some(config) => config.into(),
+    };
+    let tokenizer = get_tokenizer(hf_repo, args.tokenizer)?;
+    let config: siglip::Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
     let device = candle_examples::device(args.cpu)?;
     let vec_imgs = match args.images {
         Some(imgs) => imgs,
@@ -114,11 +164,11 @@ pub fn main() -> anyhow::Result<()> {
     Ok(())
 }
 
-pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
+pub fn get_tokenizer(hf_repo: &str, tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
     let tokenizer = match tokenizer {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
-            let api = api.model("google/siglip-base-patch16-224".to_string());
+            let api = api.model(hf_repo.to_string());
             api.get("tokenizer.json")?
         }
         Some(file) => file.into(),
diff --git a/candle-transformers/src/models/siglip.rs b/candle-transformers/src/models/siglip.rs
index b023c31f..578beea3 100644
--- a/candle-transformers/src/models/siglip.rs
+++ b/candle-transformers/src/models/siglip.rs
@@ -10,33 +10,133 @@
 use crate::models::clip::div_l2_norm;
 use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
 
+fn default_text_vocab_size() -> usize {
+    32000
+}
+
+fn default_text_hidden_size() -> usize {
+    768
+}
+
+fn default_text_intermediate_size() -> usize {
+    3072
+}
+
+fn default_text_num_hidden_layers() -> usize {
+    12
+}
+
+fn default_text_num_attention_heads() -> usize {
+    12
+}
+
+fn default_text_max_position_embeddings() -> usize {
+    64
+}
+
+fn default_text_layer_norm_eps() -> f64 {
+    1e-6
+}
+
+fn default_text_pad_token_id() -> u32 {
+    1
+}
+
+fn default_text_bos_token_id() -> u32 {
+    49406
+}
+
+fn default_text_eos_token_id() -> u32 {
+    49407
+}
+
+fn default_text_hidden_act() -> candle_nn::Activation {
+    candle_nn::Activation::GeluPytorchTanh
+}
+
 // https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
 #[derive(serde::Deserialize, Clone, Debug)]
 pub struct TextConfig {
+    #[serde(default = "default_text_vocab_size")]
     pub vocab_size: usize,
+    #[serde(default = "default_text_hidden_size")]
     pub hidden_size: usize,
+    #[serde(default = "default_text_intermediate_size")]
     pub intermediate_size: usize,
+    #[serde(default = "default_text_num_hidden_layers")]
     pub num_hidden_layers: usize,
+    #[serde(default = "default_text_num_attention_heads")]
     pub num_attention_heads: usize,
+    #[serde(default = "default_text_max_position_embeddings")]
     pub max_position_embeddings: usize,
+    #[serde(default = "default_text_hidden_act")]
     pub hidden_act: candle_nn::Activation,
+    #[serde(default = "default_text_layer_norm_eps")]
     pub layer_norm_eps: f64,
+    #[serde(default = "default_text_pad_token_id")]
     pub pad_token_id: u32,
+    #[serde(default = "default_text_bos_token_id")]
     pub bos_token_id: u32,
+    #[serde(default = "default_text_eos_token_id")]
     pub eos_token_id: u32,
 }
 
+fn default_vision_hidden_size() -> usize {
+    768
+}
+
+fn default_vision_intermediate_size() -> usize {
+    3072
+}
+
+fn default_vision_num_hidden_layers() -> usize {
+    12
+}
+
+fn default_vision_num_attention_heads() -> usize {
+    12
+}
+
+fn default_vision_num_channels() -> usize {
+    3
+}
+
+fn default_vision_image_size() -> usize {
+    224
+}
+
+fn default_vision_patch_size() -> usize {
+    16
+}
+
+fn default_vision_layer_norm_eps() -> f64 {
+    1e-6
+}
+
+fn default_vision_hidden_act() -> candle_nn::Activation {
+    candle_nn::Activation::GeluPytorchTanh
+}
+
 // https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
 #[derive(serde::Deserialize, Clone, Debug)]
 pub struct VisionConfig {
+    #[serde(default = "default_vision_hidden_size")]
     pub hidden_size: usize,
+    #[serde(default = "default_vision_intermediate_size")]
     pub intermediate_size: usize,
+    #[serde(default = "default_vision_num_hidden_layers")]
     pub num_hidden_layers: usize,
+    #[serde(default = "default_vision_num_attention_heads")]
     pub num_attention_heads: usize,
+    #[serde(default = "default_vision_num_channels")]
     pub num_channels: usize,
+    #[serde(default = "default_vision_image_size")]
     pub image_size: usize,
+    #[serde(default = "default_vision_patch_size")]
     pub patch_size: usize,
#[serde(default = "default_vision_hidden_act")] pub hidden_act: candle_nn::Activation, + #[serde(default = "default_vision_layer_norm_eps")] pub layer_norm_eps: f64, }