mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00

* Add Llama 3.1 rope * Clippy * Format * Clippy * Add support for multiple eos tokens: * Untagged either * Remove either dep and fix settings.json * Make the max positional embeddings configurable
270 lines
8.3 KiB
Rust
270 lines
8.3 KiB
Rust
use std::collections::HashMap;
|
|
|
|
use crate::models::{
|
|
clip::{text_model::Activation, vision_model::ClipVisionConfig},
|
|
llama::{Config, LlamaEosToks},
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
// original config from liuhaotian/llava
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
pub struct LLaVAConfig {
|
|
pub architectures: Vec<String>,
|
|
pub bos_token_id: usize,
|
|
pub eos_token_id: usize,
|
|
pub hidden_size: usize,
|
|
#[serde(default = "default_image_aspect_ratio")]
|
|
pub image_aspect_ratio: String,
|
|
pub image_crop_resolution: usize,
|
|
pub image_grid_pinpoints: Vec<(u32, u32)>,
|
|
pub image_split_resolution: usize,
|
|
pub intermediate_size: usize,
|
|
pub max_position_embeddings: usize,
|
|
pub mm_hidden_size: usize,
|
|
#[serde(default = "default_mm_patch_merge_type")]
|
|
pub mm_patch_merge_type: String,
|
|
pub mm_projector_type: String,
|
|
pub mm_use_im_start_end: bool,
|
|
pub mm_vision_select_feature: String,
|
|
pub mm_vision_select_layer: isize,
|
|
pub mm_vision_tower: Option<String>,
|
|
pub model_type: String,
|
|
pub num_attention_heads: usize,
|
|
pub num_hidden_layers: usize,
|
|
pub num_key_value_heads: usize,
|
|
pub pad_token_id: usize,
|
|
pub rms_norm_eps: f32,
|
|
pub rope_theta: f32,
|
|
pub tokenizer_model_max_length: Option<usize>,
|
|
pub torch_dtype: String,
|
|
pub use_cache: bool,
|
|
pub vocab_size: usize,
|
|
#[serde(default = "default_image_token_index")]
|
|
pub image_token_index: isize,
|
|
#[serde(default = "default_hf")]
|
|
pub hf: bool,
|
|
}
|
|
|
|
/// Serde fallback for `LLaVAConfig::hf` when the key is missing.
fn default_hf() -> bool {
    false
}
|
|
|
|
/// Serde fallback for `LLaVAConfig::image_token_index` when the key is missing.
fn default_image_token_index() -> isize {
    -200
}
|
|
|
|
/// Serde fallback for `LLaVAConfig::mm_patch_merge_type` when the key is missing.
fn default_mm_patch_merge_type() -> String {
    String::from("flat")
}
|
|
|
|
/// Serde fallback for `LLaVAConfig::image_aspect_ratio` when the key is missing.
fn default_image_aspect_ratio() -> String {
    String::from("square")
}
|
|
|
|
impl LLaVAConfig {
|
|
pub fn to_llama_config(&self) -> Config {
|
|
Config {
|
|
hidden_size: self.hidden_size,
|
|
intermediate_size: self.intermediate_size,
|
|
vocab_size: self.vocab_size,
|
|
num_hidden_layers: self.num_hidden_layers,
|
|
num_attention_heads: self.num_attention_heads,
|
|
num_key_value_heads: self.num_key_value_heads,
|
|
rms_norm_eps: self.rms_norm_eps as f64,
|
|
rope_theta: self.rope_theta,
|
|
bos_token_id: Some(self.bos_token_id as u32),
|
|
eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
|
|
use_flash_attn: false,
|
|
rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
|
|
max_position_embeddings: self.max_position_embeddings,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
pub struct HFLLaVATextConfig {
|
|
pub architectures: Vec<String>,
|
|
#[serde(default = "default_hidden_size")]
|
|
pub hidden_size: usize,
|
|
#[serde(default = "default_intermediate_size")]
|
|
pub intermediate_size: usize,
|
|
#[serde(default = "default_max_length")]
|
|
pub max_length: usize,
|
|
pub max_position_embeddings: usize,
|
|
pub model_type: String,
|
|
#[serde(default = "default_num_attention_heads")]
|
|
pub num_attention_heads: usize,
|
|
#[serde(default = "default_num_hidden_layers")]
|
|
pub num_hidden_layers: usize,
|
|
#[serde(default = "default_num_key_value_heads")]
|
|
pub num_key_value_heads: usize,
|
|
pub pad_token_id: usize,
|
|
pub rms_norm_eps: f32,
|
|
#[serde(default = "default_rope_theta")]
|
|
pub rope_theta: f32,
|
|
pub torch_dtype: String,
|
|
#[serde(default = "default_use_cache")]
|
|
pub use_cache: bool,
|
|
pub vocab_size: usize,
|
|
}
|
|
|
|
/// Serde fallback for `HFLLaVATextConfig::num_hidden_layers` when the key is missing.
fn default_num_hidden_layers() -> usize {
    32
}
|
|
|
|
/// Serde fallback for `HFLLaVATextConfig::use_cache` when the key is missing.
fn default_use_cache() -> bool {
    true
}
|
|
|
|
/// Serde fallback for `HFLLaVATextConfig::hidden_size` when the key is missing.
fn default_hidden_size() -> usize {
    4096
}
|
|
|
|
/// Serde fallback for `HFLLaVATextConfig::intermediate_size` when the key is missing.
fn default_intermediate_size() -> usize {
    11008
}
|
|
|
|
/// Serde fallback for the `max_length` key of both `HFLLaVATextConfig` and
/// `HFGenerationConfig` when it is missing.
fn default_max_length() -> usize {
    4096
}
|
|
|
|
/// Serde fallback for `HFLLaVATextConfig::num_attention_heads` when the key is missing.
fn default_num_attention_heads() -> usize {
    32
}
|
|
|
|
/// Serde fallback for `HFLLaVATextConfig::num_key_value_heads` when the key is missing.
fn default_num_key_value_heads() -> usize {
    32
}
|
|
|
|
/// Serde fallback for `HFLLaVATextConfig::rope_theta` when the key is missing.
fn default_rope_theta() -> f32 {
    10000.0
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
pub struct HFLLaVAVisionConfig {
|
|
pub hidden_size: usize,
|
|
pub image_size: usize,
|
|
pub intermediate_size: usize,
|
|
pub model_type: String,
|
|
pub num_attention_heads: usize,
|
|
pub num_hidden_layers: usize,
|
|
pub patch_size: usize,
|
|
pub projection_dim: usize,
|
|
pub vocab_size: usize,
|
|
}
|
|
|
|
// config from llava-v1.6-vicuna-7b-hf
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
pub struct HFLLaVAConfig {
|
|
pub architectures: Vec<String>,
|
|
pub ignore_index: isize,
|
|
pub image_grid_pinpoints: Vec<(u32, u32)>,
|
|
pub image_token_index: isize,
|
|
pub model_type: String,
|
|
pub projector_hidden_act: String,
|
|
pub text_config: HFLLaVATextConfig,
|
|
pub torch_dtype: String,
|
|
pub use_image_newline_parameter: bool,
|
|
pub vision_config: HFLLaVAVisionConfig,
|
|
pub vision_feature_layer: isize,
|
|
pub vision_feature_select_strategy: String,
|
|
pub vocab_size: usize,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
pub struct HFGenerationConfig {
|
|
pub bos_token_id: usize,
|
|
pub eos_token_id: usize,
|
|
#[serde(default = "default_max_length")]
|
|
pub max_length: usize,
|
|
pub pad_token_id: usize,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
pub struct HFPreProcessorConfig {
|
|
pub aspect_ratio_setting: String,
|
|
pub crop_size: HashMap<String, usize>,
|
|
pub do_center_crop: bool,
|
|
pub do_convert_rgb: bool,
|
|
pub do_normalize: bool,
|
|
pub do_rescale: bool,
|
|
pub do_resize: bool,
|
|
pub image_mean: Vec<f32>,
|
|
pub image_std: Vec<f32>,
|
|
pub resample: u32,
|
|
pub rescale_factor: f32,
|
|
pub size: HashMap<String, f32>,
|
|
}
|
|
|
|
impl HFLLaVAConfig {
|
|
pub fn to_clip_vision_config(&self) -> ClipVisionConfig {
|
|
ClipVisionConfig {
|
|
embed_dim: self.vision_config.hidden_size,
|
|
activation: Activation::QuickGelu,
|
|
intermediate_size: self.vision_config.intermediate_size,
|
|
num_hidden_layers: self.vision_config.num_hidden_layers,
|
|
num_attention_heads: self.vision_config.num_attention_heads,
|
|
projection_dim: self.vision_config.projection_dim,
|
|
num_channels: 3,
|
|
image_size: self.vision_config.image_size,
|
|
patch_size: self.vision_config.patch_size,
|
|
}
|
|
}
|
|
fn map_projector_type(s: &str) -> String {
|
|
if s == "gelu" {
|
|
"mlp2x_gelu".to_string()
|
|
} else {
|
|
s.to_string()
|
|
}
|
|
}
|
|
|
|
fn map_select_feature(s: &str) -> String {
|
|
if s == "default" {
|
|
"patch".to_string()
|
|
} else {
|
|
"cls_patch".to_string()
|
|
}
|
|
}
|
|
|
|
pub fn to_llava_config(
|
|
&self,
|
|
generation_config: &HFGenerationConfig,
|
|
preprocessor_config: &HFPreProcessorConfig,
|
|
) -> LLaVAConfig {
|
|
LLaVAConfig {
|
|
hf: true,
|
|
architectures: self.architectures.clone(),
|
|
bos_token_id: generation_config.bos_token_id,
|
|
eos_token_id: generation_config.eos_token_id,
|
|
hidden_size: self.text_config.hidden_size,
|
|
image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(),
|
|
image_crop_resolution: 224,
|
|
image_grid_pinpoints: self.image_grid_pinpoints.clone(),
|
|
image_split_resolution: 224,
|
|
intermediate_size: self.text_config.intermediate_size,
|
|
max_position_embeddings: self.text_config.max_position_embeddings,
|
|
mm_hidden_size: 1024,
|
|
mm_patch_merge_type: "spatial_unpad".to_string(),
|
|
mm_projector_type: Self::map_projector_type(&self.projector_hidden_act),
|
|
mm_use_im_start_end: false,
|
|
mm_vision_select_feature: Self::map_select_feature(
|
|
&self.vision_feature_select_strategy,
|
|
),
|
|
mm_vision_select_layer: self.vision_feature_layer,
|
|
mm_vision_tower: None,
|
|
model_type: self.model_type.clone(),
|
|
num_attention_heads: self.text_config.num_attention_heads,
|
|
num_hidden_layers: self.text_config.num_hidden_layers,
|
|
num_key_value_heads: self.text_config.num_key_value_heads,
|
|
pad_token_id: self.text_config.pad_token_id,
|
|
rms_norm_eps: self.text_config.rms_norm_eps,
|
|
rope_theta: self.text_config.rope_theta,
|
|
tokenizer_model_max_length: Some(4096),
|
|
torch_dtype: self.torch_dtype.clone(),
|
|
use_cache: self.text_config.use_cache,
|
|
vocab_size: self.vocab_size,
|
|
image_token_index: self.image_token_index,
|
|
}
|
|
}
|
|
}
|