mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Parse the json config for siglip models. (#2800)
* Parse the json config for siglip models. * Bump the tokenizers dependency. * Add a v2 model. * Support more v2 models.
This commit is contained in:
@ -10,33 +10,133 @@ use crate::models::clip::div_l2_norm;
|
||||
use candle::{IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
|
||||
|
||||
/// Serde default for [`TextConfig::vocab_size`], used when the key is absent from the config.
fn default_text_vocab_size() -> usize {
    32000
}
|
||||
|
||||
/// Serde default for [`TextConfig::hidden_size`], used when the key is absent from the config.
fn default_text_hidden_size() -> usize {
    768
}
|
||||
|
||||
/// Serde default for [`TextConfig::intermediate_size`], used when the key is absent from the
/// config.
fn default_text_intermediate_size() -> usize {
    3072
}
|
||||
|
||||
/// Serde default for [`TextConfig::num_hidden_layers`], used when the key is absent from the
/// config.
fn default_text_num_hidden_layers() -> usize {
    12
}
|
||||
|
||||
/// Serde default for [`TextConfig::num_attention_heads`], used when the key is absent from the
/// config.
fn default_text_num_attention_heads() -> usize {
    12
}
|
||||
|
||||
/// Serde default for [`TextConfig::max_position_embeddings`], used when the key is absent from
/// the config.
fn default_text_max_position_embeddings() -> usize {
    64
}
|
||||
|
||||
/// Serde default for [`TextConfig::layer_norm_eps`], used when the key is absent from the config.
fn default_text_layer_norm_eps() -> f64 {
    1e-6
}
|
||||
|
||||
/// Serde default for [`TextConfig::pad_token_id`], used when the key is absent from the config.
fn default_text_pad_token_id() -> u32 {
    1
}
|
||||
|
||||
/// Serde default for [`TextConfig::bos_token_id`], used when the key is absent from the config.
fn default_text_bos_token_id() -> u32 {
    49406
}
|
||||
|
||||
/// Serde default for [`TextConfig::eos_token_id`], used when the key is absent from the config.
fn default_text_eos_token_id() -> u32 {
    49407
}
|
||||
|
||||
/// Serde default for [`TextConfig::hidden_act`], used when the key is absent from the config.
fn default_text_hidden_act() -> candle_nn::Activation {
    candle_nn::Activation::GeluPytorchTanh
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
|
||||
#[derive(serde::Deserialize, Clone, Debug)]
|
||||
pub struct TextConfig {
|
||||
#[serde(default = "default_text_vocab_size")]
|
||||
pub vocab_size: usize,
|
||||
#[serde(default = "default_text_hidden_size")]
|
||||
pub hidden_size: usize,
|
||||
#[serde(default = "default_text_intermediate_size")]
|
||||
pub intermediate_size: usize,
|
||||
#[serde(default = "default_text_num_hidden_layers")]
|
||||
pub num_hidden_layers: usize,
|
||||
#[serde(default = "default_text_num_attention_heads")]
|
||||
pub num_attention_heads: usize,
|
||||
#[serde(default = "default_text_max_position_embeddings")]
|
||||
pub max_position_embeddings: usize,
|
||||
#[serde(default = "default_text_hidden_act")]
|
||||
pub hidden_act: candle_nn::Activation,
|
||||
#[serde(default = "default_text_layer_norm_eps")]
|
||||
pub layer_norm_eps: f64,
|
||||
#[serde(default = "default_text_pad_token_id")]
|
||||
pub pad_token_id: u32,
|
||||
#[serde(default = "default_text_bos_token_id")]
|
||||
pub bos_token_id: u32,
|
||||
#[serde(default = "default_text_eos_token_id")]
|
||||
pub eos_token_id: u32,
|
||||
}
|
||||
|
||||
/// Serde default for [`VisionConfig::hidden_size`], used when the key is absent from the config.
fn default_vision_hidden_size() -> usize {
    768
}
|
||||
|
||||
/// Serde default for [`VisionConfig::intermediate_size`], used when the key is absent from the
/// config.
fn default_vision_intermediate_size() -> usize {
    3072
}
|
||||
|
||||
/// Serde default for [`VisionConfig::num_hidden_layers`], used when the key is absent from the
/// config.
fn default_vision_num_hidden_layers() -> usize {
    12
}
|
||||
|
||||
/// Serde default for [`VisionConfig::num_attention_heads`], used when the key is absent from the
/// config.
fn default_vision_num_attention_heads() -> usize {
    12
}
|
||||
|
||||
/// Serde default for [`VisionConfig::num_channels`], used when the key is absent from the config.
fn default_vision_num_channels() -> usize {
    3
}
|
||||
|
||||
/// Serde default for [`VisionConfig::image_size`], used when the key is absent from the config.
fn default_vision_image_size() -> usize {
    224
}
|
||||
|
||||
/// Serde default for [`VisionConfig::patch_size`], used when the key is absent from the config.
///
/// NOTE(review): despite the `batch_size` in its name, this function is wired to the
/// `patch_size` field. The name is kept as-is so existing references keep resolving; consider
/// renaming it to `default_vision_patch_size` together with the attribute on `VisionConfig`.
fn default_vision_batch_size() -> usize {
    16
}
|
||||
|
||||
/// Serde default for [`VisionConfig::layer_norm_eps`], used when the key is absent from the
/// config.
fn default_vision_layer_norm_eps() -> f64 {
    1e-6
}
|
||||
|
||||
/// Serde default for [`VisionConfig::hidden_act`], used when the key is absent from the config.
fn default_vision_hidden_act() -> candle_nn::Activation {
    candle_nn::Activation::GeluPytorchTanh
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
|
||||
#[derive(serde::Deserialize, Clone, Debug)]
|
||||
pub struct VisionConfig {
|
||||
#[serde(default = "default_vision_hidden_size")]
|
||||
pub hidden_size: usize,
|
||||
#[serde(default = "default_vision_intermediate_size")]
|
||||
pub intermediate_size: usize,
|
||||
#[serde(default = "default_vision_num_hidden_layers")]
|
||||
pub num_hidden_layers: usize,
|
||||
#[serde(default = "default_vision_num_attention_heads")]
|
||||
pub num_attention_heads: usize,
|
||||
#[serde(default = "default_vision_num_channels")]
|
||||
pub num_channels: usize,
|
||||
#[serde(default = "default_vision_image_size")]
|
||||
pub image_size: usize,
|
||||
#[serde(default = "default_vision_batch_size")]
|
||||
pub patch_size: usize,
|
||||
#[serde(default = "default_vision_hidden_act")]
|
||||
pub hidden_act: candle_nn::Activation,
|
||||
#[serde(default = "default_vision_layer_norm_eps")]
|
||||
pub layer_norm_eps: f64,
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user