Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)

Add LLaVA support (#2234)

* first commit
* llava
* clippy and fmt
* some fixes
* minor fixes
* remove useless file
* refactor: Remove llava/constants.rs and update llava/mod.rs
* modify variable name
* modify code after clippy
* Minor tweaks.

Co-authored-by: laurent <laurent.mazare@gmail.com>
@@ -262,6 +262,20 @@ impl ClipEncoder {
        }
        Ok(xs)
    }
    // required by LLaVA
    pub fn output_hidden_states(
        &self,
        xs: &Tensor,
        causal_attention_mask: Option<&Tensor>,
    ) -> Result<Vec<Tensor>> {
        let mut xs = xs.clone();
        let mut hidden_states = Vec::new();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, causal_attention_mask)?;
            hidden_states.push(xs.clone());
        }
        Ok(hidden_states)
    }
}

/// A CLIP transformer based model.
@@ -46,6 +46,19 @@ impl ClipVisionConfig {
            patch_size: 32,
        }
    }
    pub fn clip_vit_large_patch14_336() -> Self {
        Self {
            embed_dim: 1024,
            activation: Activation::QuickGelu,
            intermediate_size: 4096,
            num_hidden_layers: 24,
            num_attention_heads: 16,
            projection_dim: 768,
            num_channels: 3,
            image_size: 336,
            patch_size: 14,
        }
    }
}

// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
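For reference, the patch arithmetic that this ViT-L/14 336px configuration implies, and that the LLaVA code below leans on, can be written out as a small sketch (the constant names are illustrative, not part of the patch):

// Patch arithmetic for a 336x336 input split into 14x14 patches.
const PATCHES_PER_SIDE: usize = 336 / 14; // 24
const NUM_PATCH_TOKENS: usize = PATCHES_PER_SIDE * PATCHES_PER_SIDE; // 576 patch tokens
const NUM_POSITIONS: usize = NUM_PATCH_TOKENS + 1; // 577, including the class token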
@@ -130,6 +143,17 @@ impl ClipVisionTransformer {
            pre_layer_norm,
        })
    }
    // required by LLaVA
    pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
        let hidden_states = pixel_values
            .apply(&self.embeddings)?
            .apply(&self.pre_layer_norm)?;
        let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
        let encoder_outputs = result.last().unwrap();
        let pooled_output = encoder_outputs.i((.., 0, ..))?;
        result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
        Ok(result)
    }
}

impl Module for ClipVisionTransformer {
@@ -388,6 +388,28 @@ pub struct Llama {
}

impl Llama {
    // required by LLaVA
    pub fn embed(&self, x: &Tensor) -> Result<Tensor> {
        self.wte.forward(x)
    }
    // required by LLaVA
    pub fn forward_input_embed(
        &self,
        input_embed: &Tensor,
        index_pos: usize,
        cache: &mut Cache,
    ) -> Result<Tensor> {
        let (_, seq_len, _) = input_embed.dims3()?;
        let mut x = input_embed.clone();
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx, cache)?;
        }
        let x = self.ln_f.forward(&x)?;
        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
        let logits = self.lm_head.forward(&x)?;
        logits.to_dtype(DType::F32)
    }

    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
        let (_b_sz, seq_len) = x.dims2()?;
        let mut x = self.wte.forward(x)?;
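Splitting the embedding lookup (embed) from the transformer pass (forward_input_embed) is what lets LLaVA splice projected image features into the token embedding sequence. The following is only a hypothetical sketch of a tiny embedding-driven greedy decode loop built on these two methods; the helper name, the 16-step budget, and the fixed batch size of 1 are assumptions, not part of the patch:

// Greedy decoding driven purely by embeddings (sketch). `llama` and `cache` are
// assumed to already exist, e.g. from Llama::load and Cache::new.
use candle::{IndexOp, Result, Tensor};
use candle_transformers::models::llama::{Cache, Llama};

fn decode_from_embeddings(
    llama: &Llama,
    cache: &mut Cache,
    prompt_embeds: &Tensor,
) -> Result<Vec<u32>> {
    let mut tokens = Vec::new();
    let mut embeds = prompt_embeds.clone(); // (1, seq_len, hidden_size)
    let mut index_pos = 0;
    for _ in 0..16 {
        let (_, seq_len, _) = embeds.dims3()?;
        // forward_input_embed returns the logits for the last position only.
        let logits = llama.forward_input_embed(&embeds, index_pos, cache)?;
        let next = logits.i(0)?.argmax(0)?.to_scalar::<u32>()?;
        tokens.push(next);
        index_pos += seq_len;
        // Only the newly sampled token is fed back; its embedding comes from Llama::embed.
        let next_id = Tensor::from_vec(vec![next], (1, 1), embeds.device())?;
        embeds = llama.embed(&next_id)?;
    }
    Ok(tokens)
}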
candle-transformers/src/models/llava/config.rs (new file)
@@ -0,0 +1,267 @@
use std::collections::HashMap;

use crate::models::{
    clip::{text_model::Activation, vision_model::ClipVisionConfig},
    llama::Config,
};
use serde::{Deserialize, Serialize};

// original config from liuhaotian/llava
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLaVAConfig {
    pub architectures: Vec<String>,
    pub bos_token_id: usize,
    pub eos_token_id: usize,
    pub hidden_size: usize,
    #[serde(default = "default_image_aspect_ratio")]
    pub image_aspect_ratio: String,
    pub image_crop_resolution: usize,
    pub image_grid_pinpoints: Vec<(u32, u32)>,
    pub image_split_resolution: usize,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub mm_hidden_size: usize,
    #[serde(default = "default_mm_patch_merge_type")]
    pub mm_patch_merge_type: String,
    pub mm_projector_type: String,
    pub mm_use_im_start_end: bool,
    pub mm_vision_select_feature: String,
    pub mm_vision_select_layer: isize,
    pub mm_vision_tower: Option<String>,
    pub model_type: String,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub pad_token_id: usize,
    pub rms_norm_eps: f32,
    pub rope_theta: f32,
    pub tokenizer_model_max_length: Option<usize>,
    pub torch_dtype: String,
    pub use_cache: bool,
    pub vocab_size: usize,
    #[serde(default = "default_image_token_index")]
    pub image_token_index: isize,
    #[serde(default = "default_hf")]
    pub hf: bool,
}

fn default_hf() -> bool {
    false
}

fn default_image_token_index() -> isize {
    -200
}

fn default_mm_patch_merge_type() -> String {
    "flat".to_string()
}

fn default_image_aspect_ratio() -> String {
    "square".to_string()
}

impl LLaVAConfig {
    pub fn to_llama_config(&self) -> Config {
        Config {
            hidden_size: self.hidden_size,
            intermediate_size: self.intermediate_size,
            vocab_size: self.vocab_size,
            num_hidden_layers: self.num_hidden_layers,
            num_attention_heads: self.num_attention_heads,
            num_key_value_heads: self.num_key_value_heads,
            rms_norm_eps: self.rms_norm_eps as f64,
            rope_theta: self.rope_theta,
            bos_token_id: Some(self.bos_token_id as u32),
            eos_token_id: Some(self.eos_token_id as u32),
            use_flash_attn: false,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVATextConfig {
    pub architectures: Vec<String>,
    #[serde(default = "default_hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "default_intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "default_max_length")]
    pub max_length: usize,
    pub max_position_embeddings: usize,
    pub model_type: String,
    #[serde(default = "default_num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "default_num_hidden_layers")]
    pub num_hidden_layers: usize,
    #[serde(default = "default_num_key_value_heads")]
    pub num_key_value_heads: usize,
    pub pad_token_id: usize,
    pub rms_norm_eps: f32,
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f32,
    pub torch_dtype: String,
    #[serde(default = "default_use_cache")]
    pub use_cache: bool,
    pub vocab_size: usize,
}

fn default_num_hidden_layers() -> usize {
    32
}

fn default_use_cache() -> bool {
    true
}

fn default_hidden_size() -> usize {
    4096
}

fn default_intermediate_size() -> usize {
    11008
}

fn default_max_length() -> usize {
    4096
}

fn default_num_attention_heads() -> usize {
    32
}

fn default_num_key_value_heads() -> usize {
    32
}

fn default_rope_theta() -> f32 {
    10000.0
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVAVisionConfig {
    pub hidden_size: usize,
    pub image_size: usize,
    pub intermediate_size: usize,
    pub model_type: String,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub patch_size: usize,
    pub projection_dim: usize,
    pub vocab_size: usize,
}

// config from llava-v1.6-vicuna-7b-hf
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVAConfig {
    pub architectures: Vec<String>,
    pub ignore_index: isize,
    pub image_grid_pinpoints: Vec<(u32, u32)>,
    pub image_token_index: isize,
    pub model_type: String,
    pub projector_hidden_act: String,
    pub text_config: HFLLaVATextConfig,
    pub torch_dtype: String,
    pub use_image_newline_parameter: bool,
    pub vision_config: HFLLaVAVisionConfig,
    pub vision_feature_layer: isize,
    pub vision_feature_select_strategy: String,
    pub vocab_size: usize,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFGenerationConfig {
    pub bos_token_id: usize,
    pub eos_token_id: usize,
    #[serde(default = "default_max_length")]
    pub max_length: usize,
    pub pad_token_id: usize,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFPreProcessorConfig {
    pub aspect_ratio_setting: String,
    pub crop_size: HashMap<String, usize>,
    pub do_center_crop: bool,
    pub do_convert_rgb: bool,
    pub do_normalize: bool,
    pub do_rescale: bool,
    pub do_resize: bool,
    pub image_mean: Vec<f32>,
    pub image_std: Vec<f32>,
    pub resample: u32,
    pub rescale_factor: f32,
    pub size: HashMap<String, f32>,
}

impl HFLLaVAConfig {
    pub fn to_clip_vision_config(&self) -> ClipVisionConfig {
        ClipVisionConfig {
            embed_dim: self.vision_config.hidden_size,
            activation: Activation::QuickGelu,
            intermediate_size: self.vision_config.intermediate_size,
            num_hidden_layers: self.vision_config.num_hidden_layers,
            num_attention_heads: self.vision_config.num_attention_heads,
            projection_dim: self.vision_config.projection_dim,
            num_channels: 3,
            image_size: self.vision_config.image_size,
            patch_size: self.vision_config.patch_size,
        }
    }
    fn map_projector_type(s: &str) -> String {
        if s == "gelu" {
            "mlp2x_gelu".to_string()
        } else {
            s.to_string()
        }
    }

    fn map_select_feature(s: &str) -> String {
        if s == "default" {
            "patch".to_string()
        } else {
            "cls_patch".to_string()
        }
    }

    pub fn to_llava_config(
        &self,
        generation_config: &HFGenerationConfig,
        preprocessor_config: &HFPreProcessorConfig,
    ) -> LLaVAConfig {
        LLaVAConfig {
            hf: true,
            architectures: self.architectures.clone(),
            bos_token_id: generation_config.bos_token_id,
            eos_token_id: generation_config.eos_token_id,
            hidden_size: self.text_config.hidden_size,
            image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(),
            image_crop_resolution: 224,
            image_grid_pinpoints: self.image_grid_pinpoints.clone(),
            image_split_resolution: 224,
            intermediate_size: self.text_config.intermediate_size,
            max_position_embeddings: self.text_config.max_position_embeddings,
            mm_hidden_size: 1024,
            mm_patch_merge_type: "spatial_unpad".to_string(),
            mm_projector_type: Self::map_projector_type(&self.projector_hidden_act),
            mm_use_im_start_end: false,
            mm_vision_select_feature: Self::map_select_feature(
                &self.vision_feature_select_strategy,
            ),
            mm_vision_select_layer: self.vision_feature_layer,
            mm_vision_tower: None,
            model_type: self.model_type.clone(),
            num_attention_heads: self.text_config.num_attention_heads,
            num_hidden_layers: self.text_config.num_hidden_layers,
            num_key_value_heads: self.text_config.num_key_value_heads,
            pad_token_id: self.text_config.pad_token_id,
            rms_norm_eps: self.text_config.rms_norm_eps,
            rope_theta: self.text_config.rope_theta,
            tokenizer_model_max_length: Some(4096),
            torch_dtype: self.torch_dtype.clone(),
            use_cache: self.text_config.use_cache,
            vocab_size: self.vocab_size,
            image_token_index: self.image_token_index,
        }
    }
}
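To see how these pieces fit together, here is a minimal, hypothetical sketch of loading the Hugging Face style JSON files for a llava-v1.6-vicuna-7b-hf checkpoint and collapsing them into the unified LLaVAConfig; the file names (config.json, generation_config.json, preprocessor_config.json) follow the usual Hugging Face layout and are assumptions for illustration, as are the printed fields:

use candle_transformers::models::llava::config::{
    HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig,
};
use std::path::Path;

fn load_llava_config(dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
    let hf_config: HFLLaVAConfig =
        serde_json::from_slice(&std::fs::read(dir.join("config.json"))?)?;
    let generation_config: HFGenerationConfig =
        serde_json::from_slice(&std::fs::read(dir.join("generation_config.json"))?)?;
    let preprocessor_config: HFPreProcessorConfig =
        serde_json::from_slice(&std::fs::read(dir.join("preprocessor_config.json"))?)?;
    // Collapse everything into the unified LLaVAConfig, then derive the sub-configs
    // that the language model and the vision tower expect.
    let llava_config = hf_config.to_llava_config(&generation_config, &preprocessor_config);
    let llama_config = llava_config.to_llama_config();
    let clip_config = hf_config.to_clip_vision_config();
    println!(
        "hidden_size={} vocab_size={} clip_image_size={}",
        llama_config.hidden_size, llava_config.vocab_size, clip_config.image_size
    );
    Ok(())
}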
candle-transformers/src/models/llava/mod.rs (new file)
@@ -0,0 +1,407 @@
pub mod config;
pub mod utils;

use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};
use crate::models::llama::{Cache, Llama};
use crate::models::with_tracing::linear;

use candle::{bail, Device, IndexOp, Result, Tensor};
use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
use fancy_regex::Regex;
use utils::get_anyres_image_grid_shape;

use config::LLaVAConfig;

fn mlp_gelu_match(mm_projector_type: &str) -> Option<usize> {
    let mlp_gelu_regex = Regex::new(r"^mlp(\d+)x_gelu$").unwrap();

    if let Ok(Some(captures)) = mlp_gelu_regex.captures(mm_projector_type) {
        if let Some(match_str) = captures.get(1) {
            let match_str = match_str.as_str();
            match_str.parse::<usize>().ok()
        } else {
            None
        }
    } else {
        None
    }
}

fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result<Tensor> {
    assert_eq!(tensor.dims().len(), 3);
    let (original_width, original_height) = *original_size;
    let tensor_dims = tensor.dims();
    let current_height = tensor_dims[1];
    let current_width = tensor_dims[2];
    let original_aspect_ratio = (original_width as f32) / (original_height as f32);
    let current_aspect_ratio = (current_width as f32) / (current_height as f32);
    if original_aspect_ratio > current_aspect_ratio {
        let scale_factor = (current_width as f32) / (original_width as f32);
        let new_height = (original_height as f32 * scale_factor).floor() as usize;
        let padding = (current_height - new_height) / 2;
        tensor.i((.., padding..current_width - padding, ..))
    } else {
        let scale_factor = (current_height as f32) / (original_height as f32);
        let new_width = (original_width as f32 * scale_factor).floor() as usize;
        let padding = (current_width - new_width) / 2;
        tensor.i((.., .., padding..current_width - padding))
    }
}

pub struct IdentityMap {}

impl Module for IdentityMap {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        Ok(x.clone())
    }
}

pub struct MMProjector {
    pub modules: Sequential,
}

impl MMProjector {
    pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result<Self> {
        if config.mm_projector_type == "linear" {
            let vb_prefix = if config.hf {
                "multi_modal_projector.linear_1"
            } else {
                "model.mm_projector.0"
            };
            let linear = linear(config.mm_hidden_size, config.hidden_size, vb.pp(vb_prefix))?;
            let modules = seq().add(linear);
            Ok(Self { modules })
        } else if let Some(mlp_depth) = mlp_gelu_match(&config.mm_projector_type) {
            let modules = if config.hf {
                let mut modules = seq().add(linear(
                    config.mm_hidden_size,
                    config.hidden_size,
                    vb.pp("multi_modal_projector.linear_1"),
                )?);
                for i in 1..mlp_depth {
                    modules = modules.add(Activation::Gelu).add(linear(
                        config.hidden_size,
                        config.hidden_size,
                        vb.pp(format!("multi_modal_projector.linear_{}", i + 1)),
                    )?);
                }
                modules
            } else {
                let mut modules = seq().add(linear(
                    config.mm_hidden_size,
                    config.hidden_size,
                    vb.pp("model.mm_projector.0"),
                )?);
                for i in 1..mlp_depth {
                    modules = modules.add(Activation::Gelu).add(linear(
                        config.hidden_size,
                        config.hidden_size,
                        vb.pp(format!("model.mm_projector.{}", i * 2)),
                    )?);
                }
                modules
            };
            Ok(Self { modules })
        } else if config.mm_projector_type == "identity" {
            Ok(Self {
                modules: seq().add(IdentityMap {}),
            })
        } else {
            bail!(
                "Unsupported MM projector type: {}",
                config.mm_projector_type
            )
        }
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.modules.forward(x)
    }
}

pub struct ClipVisionTower {
    model: ClipVisionTransformer,
    select_layer: isize,
    select_feature_method: String,
    pub config: ClipVisionConfig,
}

impl ClipVisionTower {
    pub fn new(
        vb: VarBuilder,
        select_layer: isize,
        select_feature_method: &str,
        config: &Option<ClipVisionConfig>,
    ) -> Result<Self> {
        let config = if config.is_none() {
            ClipVisionConfig::clip_vit_large_patch14_336()
        } else {
            config.clone().unwrap()
        };
        let select_layer = match select_layer {
            -1 | -2 => select_layer,
            _ => bail!("Unsupported select layer: {}", select_layer),
        };
        let model = ClipVisionTransformer::new(vb, &config)?;
        Ok(Self {
            model,
            select_layer,
            select_feature_method: select_feature_method.to_string(),
            config,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let result = self.model.output_hidden_states(x)?;
        let index = result.len() as isize + self.select_layer;
        let result = result[index as usize].clone();
        if self.select_feature_method == "cls_patch" {
            Ok(result)
        } else {
            result.i((.., 1..))
        }
    }

    pub fn num_patches_per_side(&self) -> usize {
        self.config.image_size / self.config.patch_size
    }
}

pub struct LLaVA {
    pub clip_vision_tower: ClipVisionTower,
    pub image_newline: Tensor,
    pub mm_projector: MMProjector,
    pub llama: Llama,
    config: LLaVAConfig,
    device: Device,
}

impl LLaVA {
    pub fn load(
        vb: VarBuilder,
        config: &LLaVAConfig,
        clip_vision_config: Option<ClipVisionConfig>,
    ) -> Result<Self> {
        let device = vb.device().clone();
        let llama_config = config.to_llama_config();
        let mm_projector = MMProjector::load(&vb, config)?;
        let (clip_vision_tower, image_newline, llama) = if config.hf {
            (
                ClipVisionTower::new(
                    vb.pp("vision_tower.vision_model"),
                    config.mm_vision_select_layer,
                    &config.mm_vision_select_feature,
                    &clip_vision_config,
                )?,
                vb.get(&[config.hidden_size], "image_newline")?
                    .to_device(&device)?,
                Llama::load(vb.pp("language_model"), &llama_config)?,
            )
        } else {
            (
                ClipVisionTower::new(
                    vb.pp("model.vision_tower.vision_tower.vision_model"),
                    config.mm_vision_select_layer,
                    &config.mm_vision_select_feature,
                    &clip_vision_config,
                )?,
                vb.get(&[config.hidden_size], "model.image_newline")?
                    .to_device(&device)?,
                Llama::load(vb, &llama_config)?,
            )
        };
        Ok(Self {
            clip_vision_tower,
            image_newline,
            mm_projector,
            llama,
            config: (*config).clone(),
            device,
        })
    }

    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
        let image_features = self.clip_vision_tower.forward(x)?;
        let image_features = self.mm_projector.forward(&image_features)?;
        Ok(image_features)
    }
    // currently only for a single image, as a 4-dim tensor
    pub fn prepare_inputs_labels_for_multimodal(
        &self,
        input_ids: &Tensor,
        images: &[Tensor],
        image_sizes: &[(u32, u32)],
    ) -> Result<Tensor> {
        // TODO: handle multiple images / the image newline token
        // 576: 336 (input size) / 14 (patch size) = 24; 24 * 24 + 1 (class token) = 577; 577 - 1 = 576
        let concat_images = Tensor::cat(images, 0)?;
        let image_features_together = self.encode_images(&concat_images)?;
        let split_sizes = images
            .iter()
            .map(|x| x.shape().dims()[0])
            .collect::<Vec<usize>>();
        // can be replaced by split
        let mut index_pos = 0;
        let mut image_features = Vec::new();
        for split_size in split_sizes.iter() {
            image_features.push(image_features_together.i(index_pos..index_pos + (*split_size))?);
            index_pos += *split_size;
        }
        let mm_patch_merge_type = &self.config.mm_patch_merge_type;
        let image_aspect_ratio = &self.config.image_aspect_ratio;

        let image_features = if mm_patch_merge_type == "flat" {
            image_features
                .iter()
                .map(|x| x.flatten(0, 1).unwrap())
                .collect::<Vec<Tensor>>()
        } else if mm_patch_merge_type.starts_with("spatial") {
            let mut new_image_features = Vec::new();
            for (image_idx, image_feature) in image_features.iter().enumerate() {
                let new_image_feature = if image_feature.dims()[0] > 1 {
                    let base_image_feature = image_feature.get(0).unwrap();
                    let patch_image_feature = image_feature.i(1..).unwrap();
                    let height = self.clip_vision_tower.num_patches_per_side();
                    let width = height;
                    assert_eq!(height * width, base_image_feature.dims()[0]);
                    let image_size = image_sizes[image_idx];
                    let new_image_feature = if image_aspect_ratio == "anyres" {
                        let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
                            image_size,
                            &self.config.image_grid_pinpoints,
                            self.clip_vision_tower.config.image_size as u32,
                        );
                        patch_image_feature.reshape((
                            num_patch_height as usize,
                            num_patch_width as usize,
                            height,
                            width,
                            (),
                        ))?
                    } else {
                        todo!("not implemented in original python LLaVA yet")
                    };
                    let new_image_feature = if mm_patch_merge_type.contains("unpad") {
                        let new_image_feature = new_image_feature
                            .permute((4, 0, 2, 1, 3))?
                            .flatten(1, 2)?
                            .flatten(2, 3)?;
                        let new_image_feature = unpad_image(&new_image_feature, &image_size)?;
                        let new_image_feature_dims = new_image_feature.dims();
                        let image_new_line = self
                            .image_newline
                            .reshape((self.config.hidden_size, 1, 1))?
                            .broadcast_as((
                                new_image_feature_dims[0],
                                new_image_feature_dims[1],
                                1,
                            ))?;
                        let new_image_feature =
                            Tensor::cat(&[new_image_feature, image_new_line], 2)?;
                        new_image_feature.flatten(1, 2)?.transpose(0, 1)?
                    } else {
                        new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)?
                    };
                    Tensor::cat(&[base_image_feature, new_image_feature], 0)?
                } else {
                    let new_image_feature = image_feature.get(0).unwrap();
                    if mm_patch_merge_type.contains("unpad") {
                        Tensor::cat(
                            &[
                                new_image_feature,
                                self.image_newline.clone().unsqueeze(0).unwrap(),
                            ],
                            0,
                        )
                        .unwrap()
                    } else {
                        new_image_feature
                    }
                };
                new_image_features.push(new_image_feature);
            }
            new_image_features
        } else {
            bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}")
        };
        // can easily be replaced by nonzero once it is implemented in candle
        let input_ids_vec = input_ids.squeeze(0)?.to_vec1::<i64>()?;
        let mut image_indices = {
            let mut image_indices = vec![0_i64];
            image_indices.extend(
                input_ids_vec
                    .iter()
                    .enumerate()
                    .filter_map(|(i, x)| {
                        if *x == self.config.image_token_index as i64 {
                            Some(i as i64)
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<i64>>(),
            );
            image_indices
        };
        if image_indices.len() == 1 {
            // no image, only the leading [0]
            return self.llama.embed(input_ids);
        }

        let input_ids_noim = input_ids_vec
            .iter()
            .filter_map(|x| {
                if *x != self.config.image_token_index as i64 {
                    Some(*x)
                } else {
                    None
                }
            })
            .collect::<Vec<i64>>();
        let input_ids_noim_len = input_ids_noim.len();
        image_indices.push((input_ids_noim_len) as i64);
        let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?;
        let cur_input_embeds = self.llama.embed(&input_ids_noim)?;
        // can be replaced by split once it is implemented in candle
        let input_embed_no_ims = {
            let mut input_embeds = Vec::new();
            for i in 0..image_indices.len() - 1 {
                let start = (image_indices[i]) as usize;
                let end = image_indices[i + 1] as usize;
                input_embeds.push(cur_input_embeds.i((start..end, ..))?)
            }
            input_embeds
        };

        let mut cur_new_input_embeds = Vec::new();
        for (i, image_feature) in image_features.iter().enumerate() {
            cur_new_input_embeds.push(input_embed_no_ims[i].clone());
            cur_new_input_embeds.push(image_feature.clone());
        }
        cur_new_input_embeds.push(input_embed_no_ims[image_features.len()].clone());
        let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?;
        // truncate
        let new_input_embeds =
            if let Some(tokenizer_model_max_length) = self.config.tokenizer_model_max_length {
                let (new_input_embeds_length, _) = new_input_embeds.shape().dims2()?;
                if new_input_embeds_length > tokenizer_model_max_length {
                    new_input_embeds.i((..tokenizer_model_max_length, ..))?
                } else {
                    new_input_embeds
                }
            } else {
                new_input_embeds
            };
        new_input_embeds.unsqueeze(0)
    }

    pub fn forward(
        &self,
        input_embeds: &Tensor,
        position_id: usize,
        cache: &mut Cache,
    ) -> Result<Tensor> {
        self.llama
            .forward_input_embed(input_embeds, position_id, cache)
    }
}
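A rough end-to-end usage sketch under stated assumptions: the safetensors paths, the toy prompt tokens, the dummy all-zeros 336x336 image, and the (640, 480) original image size are placeholders, and decoding is omitted. It only illustrates the intended call order (load, prepare_inputs_labels_for_multimodal, forward):

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llama::Cache;
use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA};
use std::path::PathBuf;

fn run(weights: &[PathBuf], config: &LLaVAConfig) -> candle::Result<()> {
    let device = Device::Cpu;
    let dtype = DType::F32;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(weights, dtype, &device)? };
    let llava = LLaVA::load(vb, config, None)?;
    let mut cache = Cache::new(true, dtype, &config.to_llama_config(), &device)?;

    // A toy prompt: a few text tokens with the image placeholder token in the middle.
    let tokens: Vec<i64> = vec![1, 319, config.image_token_index as i64, 13];
    let input_ids = Tensor::from_vec(tokens.clone(), (1, tokens.len()), &device)?;
    // One preprocessed image in NCHW layout, plus its original (width, height).
    let image = Tensor::zeros((1, 3, 336, 336), dtype, &device)?;
    let input_embeds =
        llava.prepare_inputs_labels_for_multimodal(&input_ids, &[image], &[(640, 480)])?;
    // The image token has been replaced by projected CLIP features at this point.
    let logits = llava.forward(&input_embeds, 0, &mut cache)?;
    println!("logits shape: {:?}", logits.shape());
    Ok(())
}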
candle-transformers/src/models/llava/utils.rs (new file)
@@ -0,0 +1,41 @@
pub fn get_anyres_image_grid_shape(
    image_size: (u32, u32),
    grid_pinpoints: &[(u32, u32)],
    patch_size: u32,
) -> (u32, u32) {
    let (width, height) = select_best_resolution(image_size, grid_pinpoints);
    (width / patch_size, height / patch_size)
}

pub fn select_best_resolution(
    original_size: (u32, u32),
    possible_resolutions: &[(u32, u32)],
) -> (u32, u32) {
    let (original_width, original_height) = original_size;
    let mut best_fit = (0, 0);
    let original_width_f = original_width as f32;
    let original_height_f = original_height as f32;
    let mut max_effective_resolution = 0_u32;
    let mut min_wasted_resolution = u32::MAX;
    for (width, height) in possible_resolutions {
        let width_f = *width as f32;
        let height_f = *height as f32;
        let scale = (width_f / original_width_f).min(height_f / original_height_f);
        let (downscaled_width, downscaled_height) = (
            (original_width_f * scale) as u32,
            (original_height_f * scale) as u32,
        );
        let effective_resolution =
            std::cmp::min((*width) * (*height), downscaled_width * downscaled_height);
        let wasted_resolution = (*width) * (*height) - effective_resolution;
        if effective_resolution > max_effective_resolution
            || (effective_resolution == max_effective_resolution
                && wasted_resolution < min_wasted_resolution)
        {
            best_fit = (*width, *height);
            max_effective_resolution = effective_resolution;
            min_wasted_resolution = wasted_resolution;
        }
    }
    best_fit
}
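A small worked example for the two helpers above; the grid pinpoints are the usual LLaVA-1.6 "anyres" candidates and are listed here purely for illustration:

use candle_transformers::models::llava::utils::{
    get_anyres_image_grid_shape, select_best_resolution,
};

fn main() {
    let pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)];
    // A 1000x800 image fits best into the 672x672 candidate (largest effective
    // resolution, least waste), which corresponds to a 2x2 grid of 336-pixel tiles.
    assert_eq!(select_best_resolution((1000, 800), &pinpoints), (672, 672));
    assert_eq!(get_anyres_image_grid_shape((1000, 800), &pinpoints, 336), (2, 2));
}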
@@ -17,6 +17,7 @@ pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
pub mod llava;
pub mod mamba;
pub mod marian;
pub mod metavoice;