From cd4d941ed10fd334333cf5793e311d2bef88a438 Mon Sep 17 00:00:00 2001
From: chenwanqq
Date: Mon, 3 Jun 2024 17:54:09 +0800
Subject: [PATCH] Add LLaVA support (#2234)

* first commit
* llava
* clippy and fmt
* some fixes
* minor fixes
* remove useless file
* refactor: Remove llava/constants.rs and update llava/mod.rs
* modify variable name
* modify code after clippy
* Minor tweaks.

---------

Co-authored-by: laurent
---
 candle-examples/examples/llava/constants.rs   |   4 +
 .../examples/llava/conversation.rs            | 114 +++++
 .../examples/llava/image_processor.rs         | 317 ++++++++++++++
 candle-examples/examples/llava/main.rs        | 316 ++++++++++++++
 candle-examples/examples/llava/readme.md      |  40 ++
 .../src/models/clip/text_model.rs             |  14 +
 .../src/models/clip/vision_model.rs           |  24 ++
 candle-transformers/src/models/llama.rs       |  22 +
 .../src/models/llava/config.rs                | 267 ++++++++++++
 candle-transformers/src/models/llava/mod.rs   | 407 ++++++++++++++++++
 candle-transformers/src/models/llava/utils.rs |  41 ++
 candle-transformers/src/models/mod.rs         |   1 +
 12 files changed, 1567 insertions(+)
 create mode 100644 candle-examples/examples/llava/constants.rs
 create mode 100644 candle-examples/examples/llava/conversation.rs
 create mode 100644 candle-examples/examples/llava/image_processor.rs
 create mode 100644 candle-examples/examples/llava/main.rs
 create mode 100644 candle-examples/examples/llava/readme.md
 create mode 100644 candle-transformers/src/models/llava/config.rs
 create mode 100644 candle-transformers/src/models/llava/mod.rs
 create mode 100644 candle-transformers/src/models/llava/utils.rs

diff --git a/candle-examples/examples/llava/constants.rs b/candle-examples/examples/llava/constants.rs
new file mode 100644
index 00000000..dff9ab63
--- /dev/null
+++ b/candle-examples/examples/llava/constants.rs
@@ -0,0 +1,4 @@
+pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
+pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
+pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
+pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";
diff --git a/candle-examples/examples/llava/conversation.rs b/candle-examples/examples/llava/conversation.rs
new file mode 100644
index 00000000..47436c63
--- /dev/null
+++ b/candle-examples/examples/llava/conversation.rs
@@ -0,0 +1,114 @@
+pub enum SeparatorStyle {
+    Two,
+    Mpt,
+}
+pub struct Conversation {
+    pub system: String,
+    pub roles: Vec<String>,
+    pub messages: Vec<(String, Option<String>)>,
+    pub offset: i32,
+    pub sep_style: SeparatorStyle,
+    pub sep: String,
+    pub sep2: Option<String>,
+    pub version: String,
+}
+
+impl Conversation {
+    pub fn new(
+        system: &str,
+        roles: &[String],
+        offset: i32,
+        sep_style: SeparatorStyle,
+        sep: &str,
+        sep2: Option<&str>,
+        version: &str,
+    ) -> Self {
+        Conversation {
+            system: system.to_string(),
+            roles: roles.to_vec(),
+            messages: Vec::new(),
+            offset,
+            sep_style,
+            sep: sep.to_string(),
+            sep2: sep2.map(|s| s.to_string()),
+            version: version.to_string(),
+        }
+    }
+
+    pub fn conv_chatml_direct() -> Self {
+        Conversation::new(
+            "<|im_start|>system\nAnswer the questions.",
+            &[
+                "<|im_start|>user\n".to_string(),
+                "<|im_start|>assistant\n".to_string(),
+            ],
+            0,
+            SeparatorStyle::Mpt,
+            "<|im_end|>",
+            None,
+            "mpt",
+        )
+    }
+
+    pub fn conv_llava_v1() -> Self {
+        Conversation::new(
+            "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.", + &[ + "USER".to_string(), + "ASSISTANT".to_string(), + ], + 0, + SeparatorStyle::Two, + " ", + Some(""), + "v1" + ) + } + + pub fn append_message(&mut self, role: String, message: Option<&str>) { + self.messages.push((role, message.map(|s| s.to_string()))) + } + + pub fn append_user_message(&mut self, message: Option<&str>) { + self.append_message(self.roles[0].clone(), message); + } + + pub fn append_assistant_message(&mut self, message: Option<&str>) { + self.append_message(self.roles[1].clone(), message); + } + + pub fn get_prompt(&self) -> String { + match self.sep_style { + SeparatorStyle::Mpt => { + let mut ret = String::new(); + ret.push_str(&self.system); + ret.push_str(&self.sep); + for (role, message) in &self.messages { + ret.push_str(role); + if let Some(message) = message { + ret.push_str(message); + }; + ret.push_str(&self.sep); + } + ret + } + SeparatorStyle::Two => { + let seps = [self.sep.clone(), self.sep2.clone().unwrap()]; + let mut ret = String::new(); + ret.push_str(&self.system); + ret.push_str(&seps[0]); + for (i, (role, message)) in self.messages.iter().enumerate() { + ret.push_str(role); + if let Some(message) = message { + ret.push_str(": "); // strictly follow the python implementation, otherwise it will cause some minor difference between tokens ^_^ + ret.push_str(message); + ret.push_str(&seps[i % 2]); + } else { + ret.push(':') + } + } + ret + } + } + } +} diff --git a/candle-examples/examples/llava/image_processor.rs b/candle-examples/examples/llava/image_processor.rs new file mode 100644 index 00000000..b50771e5 --- /dev/null +++ b/candle-examples/examples/llava/image_processor.rs @@ -0,0 +1,317 @@ +use std::cmp::min; + +use candle::{bail, DType, Device, Result, Tensor}; +use candle_transformers::models::llava::{ + config::{HFPreProcessorConfig, LLaVAConfig}, + utils::select_best_resolution, +}; +use hf_hub::api::sync::Api; +use image::{imageops::overlay, DynamicImage, GenericImageView, Rgb, RgbImage}; +use serde::{Deserialize, Serialize}; + +//This struct is mainly for LLaVA aplications, hence it's not completely compatible with python transformer CLIPImageProcessor few several preprocess that LLaVA used, including "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14". + +#[derive(Serialize, Deserialize, Debug)] +pub struct ImageProcessor { + #[serde(default = "default_size")] + pub size: u32, // this is not the same as python transformer + #[serde(default = "default_do_resize")] + pub do_resize: bool, + + //resample: u32 // 3 for PIL bicubic, equivalent to rust CatmullRom. 
Hence below we use CatmullRom + #[serde(default = "default_do_center_crop")] + pub do_center_crop: bool, + #[serde(default = "default_crop_size")] + pub crop_size: u32, // this is not the same as python transformer + #[serde(default = "default_do_rescale")] + pub do_rescale: bool, + #[serde(default = "default_rescale_factor")] + pub rescale_factor: f32, + #[serde(default = "default_do_normalize")] + pub do_normalize: bool, + #[serde(default = "default_image_mean")] + pub image_mean: Vec, + #[serde(default = "default_image_std")] + pub image_std: Vec, +} + +fn default_size() -> u32 { + 224 +} + +fn default_do_resize() -> bool { + true +} + +fn default_do_center_crop() -> bool { + true +} + +fn default_crop_size() -> u32 { + 224 +} + +fn default_do_rescale() -> bool { + true +} + +fn default_rescale_factor() -> f32 { + 1.0 / 255.0 +} + +fn default_do_normalize() -> bool { + true +} + +fn default_image_mean() -> Vec { + vec![0.48145466, 0.4578275, 0.40821073] +} + +fn default_image_std() -> Vec { + vec![0.26862954, 0.2613026, 0.2757771] +} + +impl ImageProcessor { + pub fn from_pretrained(clip_id: &str) -> Result { + let api = Api::new().map_err(|e| candle::Error::Msg(e.to_string()))?; + let api = api.model(clip_id.to_string()); + let config_filename = api + .get("preprocessor_config.json") + .map_err(|e| candle::Error::Msg(e.to_string()))?; + let image_processor = + serde_json::from_slice(&std::fs::read(config_filename).map_err(candle::Error::Io)?) + .map_err(|e| candle::Error::Msg(e.to_string()))?; + Ok(image_processor) + } + + pub fn from_hf_preprocessor_config(hf_preprocessor_config: &HFPreProcessorConfig) -> Self { + Self { + size: hf_preprocessor_config.size["shortest_edge"] as u32, + do_resize: hf_preprocessor_config.do_resize, + do_center_crop: hf_preprocessor_config.do_center_crop, + crop_size: hf_preprocessor_config.crop_size["height"] as u32, + do_rescale: hf_preprocessor_config.do_rescale, + rescale_factor: hf_preprocessor_config.rescale_factor, + do_normalize: hf_preprocessor_config.do_normalize, + image_mean: hf_preprocessor_config.image_mean.clone(), + image_std: hf_preprocessor_config.image_std.clone(), + } + } + + ///shortest edge to self.resize, other edge is resized to maintain aspect ratio + pub fn resize(&self, image: &DynamicImage) -> DynamicImage { + let (width, height) = image.dimensions(); + let size = self.size; + if width == size && height == size { + image.clone() + } else { + let (new_width, new_height) = if width < height { + ( + size, + (((size * height) as f32) / width as f32).ceil() as u32, + ) + } else { + ( + (((size * width) as f32) / height as f32).ceil() as u32, + size, + ) + }; + image.resize( + new_width, + new_height, + image::imageops::FilterType::CatmullRom, + ) + } + } + + pub fn center_crop(&self, image: &DynamicImage) -> DynamicImage { + let (width, height) = image.dimensions(); + let crop_size = self.crop_size; + let (left, top) = calculate_middle((width, height), (crop_size, crop_size)); + image.crop_imm(left, top, crop_size, crop_size) + } + + pub fn to_tensor(&self, image: &DynamicImage) -> Result { + let img = image.to_rgb8().into_raw(); + let (width, height) = image.dimensions(); + Tensor::from_vec(img, (height as usize, width as usize, 3), &Device::Cpu)? 
+ .to_dtype(DType::F32) // only for internal compute + } + + pub fn rescale(&self, tensor: &Tensor) -> Result { + let rescale_factor = self.rescale_factor as f64; + tensor.affine(rescale_factor, 0.0) + } + + pub fn normalize(&self, tensor: &Tensor) -> Result { + let image_mean = self.image_mean.clone(); + let image_std = self.image_std.clone(); + let mean = Tensor::from_vec(image_mean, (3,), &Device::Cpu)?; + let std = Tensor::from_vec(image_std, (3,), &Device::Cpu)?; + tensor.broadcast_sub(&mean)?.broadcast_div(&std) + } + + pub fn to_channel_dimension_format(&self, tensor: &Tensor) -> Result { + tensor.permute((2, 0, 1)) + } + + pub fn preprocess(&self, image: &DynamicImage) -> Result { + let image = if self.do_resize { + self.resize(image) + } else { + image.clone() + }; + let image = if self.do_center_crop { + self.center_crop(&image) + } else { + image + }; + let tensor = self.to_tensor(&image)?; + let tensor = if self.do_rescale { + self.rescale(&tensor)? + } else { + tensor + }; + let tensor = if self.do_normalize { + self.normalize(&tensor)? + } else { + tensor + }; + self.to_channel_dimension_format(&tensor) + } +} + +pub fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) { + let (width, height) = image_size; + let (center_width, center_height) = center_size; + let left = if width <= center_width { + 0 + } else { + ((width as f32 - center_width as f32) / 2.0).ceil() as u32 + }; + let top = if height <= center_height { + 0 + } else { + ((height as f32 - center_height as f32) / 2.0).ceil() as u32 + }; + (left, top) +} + +pub fn process_image( + image: &DynamicImage, + processor: &ImageProcessor, + llava_config: &LLaVAConfig, +) -> candle::Result { + if llava_config.image_aspect_ratio == *"square" { + processor.preprocess(image)?.unsqueeze(0) + } else if llava_config.image_aspect_ratio == *"anyres" { + process_anyres_image(image, processor, &llava_config.image_grid_pinpoints) + } else if llava_config.image_aspect_ratio == *"pad" { + process_pad_image(image, processor) + } else { + bail!("Invalid image aspect ratio") + } +} + +fn process_pad_image(image: &DynamicImage, processor: &ImageProcessor) -> Result { + let mean_color = processor + .image_mean + .iter() + .map(|x| ((*x) * 255.0) as u8) + .collect::>(); + let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]); + let image_padded = expand2square(image, mean_color); + processor.preprocess(&image_padded) +} + +fn process_anyres_image( + image: &DynamicImage, + processor: &ImageProcessor, + grid_pinpoints: &[(u32, u32)], +) -> Result { + let original_size = image.dimensions(); + let best_resolution = select_best_resolution(original_size, grid_pinpoints); + let image_padded = resize_and_pad_image(image, best_resolution); + let image_original_resize = image.resize_exact( + processor.size, + processor.size, + image::imageops::FilterType::CatmullRom, + ); + let mut patches = vec![image_original_resize]; + for patch in divide_to_patches(&image_padded, processor.crop_size) { + patches.push(patch); + } + let tensors = patches + .iter() + .map(|patch| processor.preprocess(patch)) + .collect::>>()?; + Tensor::stack(&tensors, 0) +} + +fn expand2square(image: &DynamicImage, background_color: Rgb) -> DynamicImage { + let (width, height) = image.dimensions(); + match width.cmp(&height) { + std::cmp::Ordering::Less => { + let mut new_image = + DynamicImage::from(RgbImage::from_pixel(height, height, background_color)); + overlay(&mut new_image, image, ((height - width) / 2) as i64, 0); + new_image 
+ } + std::cmp::Ordering::Equal => image.clone(), + std::cmp::Ordering::Greater => { + let mut new_image = + DynamicImage::from(RgbImage::from_pixel(width, width, background_color)); + overlay(&mut new_image, image, 0, ((width - height) / 2) as i64); + new_image + } + } +} + +fn resize_and_pad_image(image: &DynamicImage, target_resolution: (u32, u32)) -> DynamicImage { + let (original_width, original_height) = image.dimensions(); + let original_width_f = original_width as f32; + let original_height_f = original_height as f32; + let (target_width, target_height) = target_resolution; + let target_width_f = target_width as f32; + let target_height_f = target_height as f32; + let scale_w = target_width_f / original_width_f; + let scale_h = target_height_f / original_height_f; + let (new_width, new_height) = if scale_w < scale_h { + ( + target_width, + min((original_height_f * scale_w).ceil() as u32, target_height), + ) + } else { + ( + min((original_width_f * scale_h).ceil() as u32, target_width), + target_height, + ) + }; + let resized_image = image.resize_exact( + new_width, + new_height, + image::imageops::FilterType::CatmullRom, + ); + let mut new_image = DynamicImage::new_rgb8(target_width, target_height); + let (paste_x, paste_y) = + calculate_middle((target_width, target_height), (new_width, new_height)); + overlay( + &mut new_image, + &resized_image, + paste_x.into(), + paste_y.into(), + ); + new_image +} + +fn divide_to_patches(image: &DynamicImage, patch_size: u32) -> Vec { + let (width, height) = image.dimensions(); + let mut patches = Vec::new(); + for y in (0..height).step_by(patch_size as usize) { + for x in (0..width).step_by(patch_size as usize) { + let patch = image.crop_imm(x, y, patch_size, patch_size); + patches.push(patch); + } + } + patches +} diff --git a/candle-examples/examples/llava/main.rs b/candle-examples/examples/llava/main.rs new file mode 100644 index 00000000..d6c911af --- /dev/null +++ b/candle-examples/examples/llava/main.rs @@ -0,0 +1,316 @@ +pub mod constants; +pub mod conversation; +pub mod image_processor; + +use candle_transformers::generation::{LogitsProcessor, Sampling}; +use candle_transformers::models::llama::Cache; + +use anyhow::{bail, Error as E, Result}; +use candle::{DType, Device, IndexOp, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::llava::config::{ + HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig, +}; +use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA}; +use clap::Parser; +use constants::*; +use conversation::Conversation; +use hf_hub::api::sync::Api; +use image_processor::{process_image, ImageProcessor}; +use std::io::Write; +use tokenizers::Tokenizer; + +#[derive(Parser, Debug)] +#[command(author, version, about,long_about=None)] +struct Args { + #[arg(long, default_value = "llava-hf/llava-v1.6-vicuna-7b-hf")] + model_path: String, + #[arg(long, default_value = "tokenizer/tokenizer.json")] + tokenizer_path: String, + #[arg(long)] + model_base: Option, + #[arg(long)] + image_file: String, // Required + #[arg(long)] + conv_mode: Option, + #[arg(long, default_value_t = 0.2)] + temperature: f32, + #[arg(long, default_value_t = 512)] + max_new_tokens: usize, + #[arg(long, action)] + hf: bool, + #[arg(long, action)] + cpu: bool, + #[arg(long, action)] + no_kv_cache: bool, + #[arg(long)] + prompt: String, + /// The seed to use when generating random samples. Copy from candle llama. Not exist in python llava. 
+ #[arg(long, default_value_t = 299792458)] + seed: u64, +} + +//from https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs +fn load_image>( + path: T, + processor: &ImageProcessor, + llava_config: &LLaVAConfig, + dtype: DType, +) -> Result<((u32, u32), Tensor)> { + let img = image::io::Reader::open(path)?.decode()?; + let img_tensor = process_image(&img, processor, llava_config)?; + Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?)) +} + +fn get_model_name_from_path(model_path: &str) -> String { + let model_paths: Vec = model_path + .trim_matches('/') + .split('/') + .map(|s| s.to_string()) + .collect(); + if model_paths.last().unwrap().starts_with("checkpoint-") { + format!( + "{}_{}", + model_paths[model_paths.len() - 2], + model_paths.last().unwrap() + ) + } else { + model_paths.last().unwrap().to_string() + } +} + +fn duplicate_vec(vec: &[T], n: usize) -> Vec +where + T: Clone, +{ + let mut res = Vec::new(); + for _ in 0..n { + res.extend(vec.to_owned()); + } + res +} + +fn insert_separator(x: Vec>, sep: Vec) -> Vec> +where + T: Clone, +{ + let sep = vec![sep]; + let sep = duplicate_vec(&sep, x.len()); + let mut res = x + .iter() + .zip(sep.iter()) + .flat_map(|(x, y)| vec![x.clone(), y.clone()]) + .collect::>>(); + res.pop(); + res +} + +fn tokenizer_image_token( + prompt: &str, + tokenizer: &Tokenizer, + image_token_index: i64, + llava_config: &LLaVAConfig, +) -> Result { + let prompt_chunks = prompt + .split("") + .map(|s| { + tokenizer + .encode(s, true) + .unwrap() + .get_ids() + .to_vec() + .iter() + .map(|x| *x as i64) + .collect() + }) + .collect::>>(); + let mut input_ids = Vec::new(); + let mut offset = 0; + if !prompt_chunks.is_empty() + && !prompt_chunks[0].is_empty() + && prompt_chunks[0][0] == llava_config.bos_token_id as i64 + { + offset = 1; + input_ids.push(prompt_chunks[0][0]); + } + + for x in insert_separator( + prompt_chunks, + duplicate_vec(&[image_token_index], offset + 1), + ) + .iter() + { + input_ids.extend(x[1..].to_vec()) + } + let input_len = input_ids.len(); + Tensor::from_vec(input_ids, (1, input_len), &Device::Cpu).map_err(E::msg) +} + +fn main() -> Result<()> { + let mut args = Args::parse(); + let device = candle_examples::device(args.cpu)?; + println!("Start loading model"); + let api = Api::new()?; + let api = api.model(args.model_path.clone()); + let (llava_config, tokenizer, clip_vision_config, image_processor) = if args.hf { + let config_filename = api.get("config.json")?; + let hf_llava_config: HFLLaVAConfig = + serde_json::from_slice(&std::fs::read(config_filename)?)?; + let generation_config_filename = api.get("generation_config.json")?; + let generation_config: HFGenerationConfig = + serde_json::from_slice(&std::fs::read(generation_config_filename)?)?; + let preprocessor_config_filename = api.get("preprocessor_config.json")?; + let preprocessor_config: HFPreProcessorConfig = + serde_json::from_slice(&std::fs::read(preprocessor_config_filename)?)?; + let llava_config = + hf_llava_config.to_llava_config(&generation_config, &preprocessor_config); + let tokenizer_filename = api.get("tokenizer.json")?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let clip_vision_config = hf_llava_config.to_clip_vision_config(); + ( + llava_config, + tokenizer, + Some(clip_vision_config), + ImageProcessor::from_hf_preprocessor_config(&preprocessor_config), + ) + } else { + let config_filename = api.get("config.json")?; + let llava_config: LLaVAConfig = 
serde_json::from_slice(&std::fs::read(config_filename)?)?; + let tokenizer = Tokenizer::from_file(&args.tokenizer_path) + .map_err(|e| E::msg(format!("Error loading {}: {}", &args.tokenizer_path, e)))?; + ( + llava_config.clone(), + tokenizer, + None, + ImageProcessor::from_pretrained(&llava_config.mm_vision_tower.unwrap())?, + ) + }; + + let llama_config = llava_config.to_llama_config(); + let dtype: DType = match llava_config.torch_dtype.as_str() { + "float16" => DType::F16, + "bfloat16" => DType::BF16, + _ => bail!("unsupported dtype"), + }; + + let eos_token_id = llava_config.eos_token_id; + + println!("setting kv cache"); + let mut cache = Cache::new(!args.no_kv_cache, dtype, &llama_config, &device)?; + + println!("loading model weights"); + + let weight_filenames = + candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_filenames, dtype, &device)? }; + let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?; + + println!("generating conv template"); + let image_token_se = format!( + "{}{}{}", + DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN + ); + let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) { + if llava_config.mm_use_im_start_end { + args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se) + } else { + args.prompt.replace(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN) + } + } else if llava_config.mm_use_im_start_end { + format!("{}\n{}", image_token_se, args.prompt) + } else { + format!("{}\n{}", DEFAULT_IMAGE_TOKEN, args.prompt) + }; + + let model_name = get_model_name_from_path(&args.model_path).to_lowercase(); + let conv_mode = if model_name.contains("llama-2") { + "llava_llama_2" + } else if model_name.contains("mistral") { + "mistral_instruct" + } else if model_name.contains("v1.6-34b") { + "chatml_direct" + } else if model_name.contains("v1") { + "llava_v1" + } else if model_name.contains("mpt") { + "mpt" + } else { + "llava_v0" + }; + if args.conv_mode.is_some() && args.conv_mode.as_deref() != Some(conv_mode) { + println!( + "Warning: the model is trained with {}, but you are using {}", + conv_mode, + args.conv_mode.as_deref().unwrap() + ); + } else { + args.conv_mode = Some(conv_mode.to_string()); + } + + let mut conv = match args.conv_mode { + Some(conv_mode) => match conv_mode.as_str() { + "chatml_direct" => Conversation::conv_chatml_direct(), + "llava_v1" => Conversation::conv_llava_v1(), + _ => todo!("not implement yet"), + }, + None => bail!("conv_mode is required"), + }; + conv.append_user_message(Some(&qs)); + conv.append_assistant_message(None); + let prompt = conv.get_prompt(); + println!("loading image"); + let (image_size, image_tensor) = + load_image(&args.image_file, &image_processor, &llava_config, dtype) + .map_err(|e| E::msg(format!("Error loading {}: {}", &args.image_file, e)))?; + let image_tensor = image_tensor.to_device(&device)?; + + let mut logits_processor = { + let temperature = f64::from(args.temperature); + let sampling = if temperature <= 0. 
{ + Sampling::ArgMax + } else { + Sampling::All { temperature } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; + + // get input tokens + let tokens = tokenizer_image_token( + &prompt, + &tokenizer, + llava_config.image_token_index as i64, + &llava_config, + )?; + let mut input_embeds = + llava.prepare_inputs_labels_for_multimodal(&tokens, &[image_tensor], &[image_size])?; + //inference loop, based on https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs + let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer); + let mut index_pos = 0; + for index in 0..args.max_new_tokens { + let (_, input_embeds_len, _) = input_embeds.dims3()?; + let (context_size, context_index) = if cache.use_kv_cache && index > 0 { + (1, index_pos) + } else { + (input_embeds_len, 0) + }; + let input = input_embeds.i((.., input_embeds_len.saturating_sub(context_size).., ..))?; + let logits = llava.forward(&input, context_index, &mut cache)?; //[1,32000] + let logits = logits.squeeze(0)?; + let (_, input_len, _) = input.dims3()?; + index_pos += input_len; + let next_token = logits_processor.sample(&logits)?; + let next_token_tensor = Tensor::from_vec(vec![next_token], 1, &device)?; + let next_embeds = llava.llama.embed(&next_token_tensor)?.unsqueeze(0)?; + input_embeds = Tensor::cat(&[input_embeds, next_embeds], 1)?; + if next_token == eos_token_id as u32 { + break; + } + if let Some(t) = tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + Ok(()) +} diff --git a/candle-examples/examples/llava/readme.md b/candle-examples/examples/llava/readme.md new file mode 100644 index 00000000..7ce84970 --- /dev/null +++ b/candle-examples/examples/llava/readme.md @@ -0,0 +1,40 @@ +# candle-llava + +LLaVA (Large Language-and-Vision Assistant) is an end-to-end trained large +multimodal model. This example is from [candle-llava](https://github.com/chenwanqq/candle-llava) + +The code is based on [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA), Hence the llava-hf version of config may perform differently. + +## model zoo +* [liuhaotian/LLaVA](https://huggingface.co/liuhaotian) +* [llava-hf](https://huggingface.co/llava-hf) + +Right now this has been tested on `liuhaotian/llava-v1.6-vicuna-7b` and +`llava-hf/llava-v1.6-vicuna-7b-hf`. Memory usage might have room for optimization. + +## Tokenizer Setup +The llava-hf models contain a `tokenizer.json` file so can be used directly with +the `-hf` command line flag. + +For the original llava models, you can use the following code to generate the `tokenizer.json` file. + +```bash +conda create -n llava python=3.10 +pip install transformers protobuf +conda activate llava +python -c "from transformers import AutoTokenizer;tokenizer=AutoTokenizer.from_pretrained('liuhaotian/llava-v1.6-vicuna-7b');tokenizer.save_pretrained('tokenizer')" +``` +Then the `tokenizer.json` file should be in `tokenizer/tokenizer.json` (which is the default path). + + +## eval + +```bash +cargo run --example llava --features cuda -- --image-file "llava_logo.png" --prompt "is this a cat?" --hf # default args, use llava-hf/llava-v1.6-vicuna-7b-hf. image-file is required^_^ +cargo run --example llava --features cuda -- --model-path liuhaotian/llava-v1.6-vicuna-7b --image-file "llava_logo.png" --prompt "is this a cat?" 
# use liuhaotian/llava-v1.6-vicuna-7b, tokenizer setup should be done
+```
+
+## Major Limitations
+1. Currently only the llama-2/vicuna LLM is supported; Mistral is not supported yet.
+2. Some ops such as `split`, `nonzero`, and `where` are not yet supported by candle.
+3. Lack of quantization and LoRA support.
diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs
index d3ba26ff..4e4b4c90 100644
--- a/candle-transformers/src/models/clip/text_model.rs
+++ b/candle-transformers/src/models/clip/text_model.rs
@@ -262,6 +262,20 @@ impl ClipEncoder {
         }
         Ok(xs)
     }
+    // required by LLaVA
+    pub fn output_hidden_states(
+        &self,
+        xs: &Tensor,
+        causal_attention_mask: Option<&Tensor>,
+    ) -> Result<Vec<Tensor>> {
+        let mut xs = xs.clone();
+        let mut hidden_states = Vec::new();
+        for layer in self.layers.iter() {
+            xs = layer.forward(&xs, causal_attention_mask)?;
+            hidden_states.push(xs.clone());
+        }
+        Ok(hidden_states)
+    }
 }
 
 /// A CLIP transformer based model.
diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs
index 88992434..e64cab16 100644
--- a/candle-transformers/src/models/clip/vision_model.rs
+++ b/candle-transformers/src/models/clip/vision_model.rs
@@ -46,6 +46,19 @@ impl ClipVisionConfig {
             patch_size: 32,
         }
     }
+    pub fn clip_vit_large_patch14_336() -> Self {
+        Self {
+            embed_dim: 1024,
+            activation: Activation::QuickGelu,
+            intermediate_size: 4096,
+            num_hidden_layers: 24,
+            num_attention_heads: 16,
+            projection_dim: 768,
+            num_channels: 3,
+            image_size: 336,
+            patch_size: 14,
+        }
+    }
 }
 
 // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
@@ -130,6 +143,17 @@ impl ClipVisionTransformer {
             pre_layer_norm,
         })
     }
+    // required by LLaVA
+    pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
+        let hidden_states = pixel_values
+            .apply(&self.embeddings)?
+ .apply(&self.pre_layer_norm)?; + let mut result = self.encoder.output_hidden_states(&hidden_states, None)?; + let encoder_outputs = result.last().unwrap(); + let pooled_output = encoder_outputs.i((.., 0, ..))?; + result.push(self.final_layer_norm.forward(&pooled_output)?.clone()); + Ok(result) + } } impl Module for ClipVisionTransformer { diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index 57d2f593..a1f43d35 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -388,6 +388,28 @@ pub struct Llama { } impl Llama { + // required by LLaVA + pub fn embed(&self, x: &Tensor) -> Result { + self.wte.forward(x) + } + // required by LLaVA + pub fn forward_input_embed( + &self, + input_embed: &Tensor, + index_pos: usize, + cache: &mut Cache, + ) -> Result { + let (_, seq_len, _) = input_embed.dims3()?; + let mut x = input_embed.clone(); + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx, cache)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?.contiguous()?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result { let (_b_sz, seq_len) = x.dims2()?; let mut x = self.wte.forward(x)?; diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs new file mode 100644 index 00000000..d2d47003 --- /dev/null +++ b/candle-transformers/src/models/llava/config.rs @@ -0,0 +1,267 @@ +use std::collections::HashMap; + +use crate::models::{ + clip::{text_model::Activation, vision_model::ClipVisionConfig}, + llama::Config, +}; +use serde::{Deserialize, Serialize}; + +// original config from liuhaotian/llava +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct LLaVAConfig { + pub architectures: Vec, + pub bos_token_id: usize, + pub eos_token_id: usize, + pub hidden_size: usize, + #[serde(default = "default_image_aspect_ratio")] + pub image_aspect_ratio: String, + pub image_crop_resolution: usize, + pub image_grid_pinpoints: Vec<(u32, u32)>, + pub image_split_resolution: usize, + pub intermediate_size: usize, + pub max_position_embeddings: usize, + pub mm_hidden_size: usize, + #[serde(default = "default_mm_patch_merge_type")] + pub mm_patch_merge_type: String, + pub mm_projector_type: String, + pub mm_use_im_start_end: bool, + pub mm_vision_select_feature: String, + pub mm_vision_select_layer: isize, + pub mm_vision_tower: Option, + pub model_type: String, + pub num_attention_heads: usize, + pub num_hidden_layers: usize, + pub num_key_value_heads: usize, + pub pad_token_id: usize, + pub rms_norm_eps: f32, + pub rope_theta: f32, + pub tokenizer_model_max_length: Option, + pub torch_dtype: String, + pub use_cache: bool, + pub vocab_size: usize, + #[serde(default = "default_image_token_index")] + pub image_token_index: isize, + #[serde(default = "default_hf")] + pub hf: bool, +} + +fn default_hf() -> bool { + false +} + +fn default_image_token_index() -> isize { + -200 +} + +fn default_mm_patch_merge_type() -> String { + "flat".to_string() +} + +fn default_image_aspect_ratio() -> String { + "square".to_string() +} + +impl LLaVAConfig { + pub fn to_llama_config(&self) -> Config { + Config { + hidden_size: self.hidden_size, + intermediate_size: self.intermediate_size, + vocab_size: self.vocab_size, + num_hidden_layers: self.num_hidden_layers, + num_attention_heads: 
self.num_attention_heads, + num_key_value_heads: self.num_key_value_heads, + rms_norm_eps: self.rms_norm_eps as f64, + rope_theta: self.rope_theta, + bos_token_id: Some(self.bos_token_id as u32), + eos_token_id: Some(self.eos_token_id as u32), + use_flash_attn: false, + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFLLaVATextConfig { + pub architectures: Vec, + #[serde(default = "default_hidden_size")] + pub hidden_size: usize, + #[serde(default = "default_intermediate_size")] + pub intermediate_size: usize, + #[serde(default = "default_max_length")] + pub max_length: usize, + pub max_position_embeddings: usize, + pub model_type: String, + #[serde(default = "default_num_attention_heads")] + pub num_attention_heads: usize, + #[serde(default = "default_num_hidden_layers")] + pub num_hidden_layers: usize, + #[serde(default = "default_num_key_value_heads")] + pub num_key_value_heads: usize, + pub pad_token_id: usize, + pub rms_norm_eps: f32, + #[serde(default = "default_rope_theta")] + pub rope_theta: f32, + pub torch_dtype: String, + #[serde(default = "default_use_cache")] + pub use_cache: bool, + pub vocab_size: usize, +} + +fn default_num_hidden_layers() -> usize { + 32 +} + +fn default_use_cache() -> bool { + true +} + +fn default_hidden_size() -> usize { + 4096 +} + +fn default_intermediate_size() -> usize { + 11008 +} + +fn default_max_length() -> usize { + 4096 +} + +fn default_num_attention_heads() -> usize { + 32 +} + +fn default_num_key_value_heads() -> usize { + 32 +} + +fn default_rope_theta() -> f32 { + 10000.0 +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFLLaVAVisionConfig { + pub hidden_size: usize, + pub image_size: usize, + pub intermediate_size: usize, + pub model_type: String, + pub num_attention_heads: usize, + pub num_hidden_layers: usize, + pub patch_size: usize, + pub projection_dim: usize, + pub vocab_size: usize, +} + +// config from llava-v1.6-vicuna-7b-hf +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFLLaVAConfig { + pub architectures: Vec, + pub ignore_index: isize, + pub image_grid_pinpoints: Vec<(u32, u32)>, + pub image_token_index: isize, + pub model_type: String, + pub projector_hidden_act: String, + pub text_config: HFLLaVATextConfig, + pub torch_dtype: String, + pub use_image_newline_parameter: bool, + pub vision_config: HFLLaVAVisionConfig, + pub vision_feature_layer: isize, + pub vision_feature_select_strategy: String, + pub vocab_size: usize, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFGenerationConfig { + pub bos_token_id: usize, + pub eos_token_id: usize, + #[serde(default = "default_max_length")] + pub max_length: usize, + pub pad_token_id: usize, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFPreProcessorConfig { + pub aspect_ratio_setting: String, + pub crop_size: HashMap, + pub do_center_crop: bool, + pub do_convert_rgb: bool, + pub do_normalize: bool, + pub do_rescale: bool, + pub do_resize: bool, + pub image_mean: Vec, + pub image_std: Vec, + pub resample: u32, + pub rescale_factor: f32, + pub size: HashMap, +} + +impl HFLLaVAConfig { + pub fn to_clip_vision_config(&self) -> ClipVisionConfig { + ClipVisionConfig { + embed_dim: self.vision_config.hidden_size, + activation: Activation::QuickGelu, + intermediate_size: self.vision_config.intermediate_size, + num_hidden_layers: self.vision_config.num_hidden_layers, + num_attention_heads: self.vision_config.num_attention_heads, + projection_dim: self.vision_config.projection_dim, + num_channels: 
3, + image_size: self.vision_config.image_size, + patch_size: self.vision_config.patch_size, + } + } + fn map_projector_type(s: &str) -> String { + if s == "gelu" { + "mlp2x_gelu".to_string() + } else { + s.to_string() + } + } + + fn map_select_feature(s: &str) -> String { + if s == "default" { + "patch".to_string() + } else { + "cls_patch".to_string() + } + } + + pub fn to_llava_config( + &self, + generation_config: &HFGenerationConfig, + preprocessor_config: &HFPreProcessorConfig, + ) -> LLaVAConfig { + LLaVAConfig { + hf: true, + architectures: self.architectures.clone(), + bos_token_id: generation_config.bos_token_id, + eos_token_id: generation_config.eos_token_id, + hidden_size: self.text_config.hidden_size, + image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(), + image_crop_resolution: 224, + image_grid_pinpoints: self.image_grid_pinpoints.clone(), + image_split_resolution: 224, + intermediate_size: self.text_config.intermediate_size, + max_position_embeddings: self.text_config.max_position_embeddings, + mm_hidden_size: 1024, + mm_patch_merge_type: "spatial_unpad".to_string(), + mm_projector_type: Self::map_projector_type(&self.projector_hidden_act), + mm_use_im_start_end: false, + mm_vision_select_feature: Self::map_select_feature( + &self.vision_feature_select_strategy, + ), + mm_vision_select_layer: self.vision_feature_layer, + mm_vision_tower: None, + model_type: self.model_type.clone(), + num_attention_heads: self.text_config.num_attention_heads, + num_hidden_layers: self.text_config.num_hidden_layers, + num_key_value_heads: self.text_config.num_key_value_heads, + pad_token_id: self.text_config.pad_token_id, + rms_norm_eps: self.text_config.rms_norm_eps, + rope_theta: self.text_config.rope_theta, + tokenizer_model_max_length: Some(4096), + torch_dtype: self.torch_dtype.clone(), + use_cache: self.text_config.use_cache, + vocab_size: self.vocab_size, + image_token_index: self.image_token_index, + } + } +} diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs new file mode 100644 index 00000000..caa8737a --- /dev/null +++ b/candle-transformers/src/models/llava/mod.rs @@ -0,0 +1,407 @@ +pub mod config; +pub mod utils; + +use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer}; +use crate::models::llama::{Cache, Llama}; +use crate::models::with_tracing::linear; + +use candle::{bail, Device, IndexOp, Result, Tensor}; +use candle_nn::{seq, Activation, Module, Sequential, VarBuilder}; +use fancy_regex::Regex; +use utils::get_anyres_image_grid_shape; + +use config::LLaVAConfig; + +fn mlp_gelu_match(mm_projector_type: &str) -> Option { + let mlp_gelu_regex = Regex::new(r"^mlp(\d+)x_gelu$").unwrap(); + + if let Ok(Some(captures)) = mlp_gelu_regex.captures(mm_projector_type) { + if let Some(match_str) = captures.get(1) { + let match_str = match_str.as_str(); + match_str.parse::().ok() + } else { + None + } + } else { + None + } +} + +fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result { + assert_eq!(tensor.dims().len(), 3); + let (original_width, original_height) = *original_size; + let tensor_dims = tensor.dims(); + let current_height = tensor_dims[1]; + let current_width = tensor_dims[2]; + let original_aspect_ratio = (original_width as f32) / (original_height as f32); + let current_aspect_ratio = (current_width as f32) / (current_height as f32); + if original_aspect_ratio > current_aspect_ratio { + let scale_factor = (current_width as f32) / (original_width as f32); + let 
new_height = (original_height as f32 * scale_factor).floor() as usize; + let padding = (current_height - new_height) / 2; + tensor.i((.., padding..current_width - padding, ..)) + } else { + let scale_factor = (current_height as f32) / (original_height as f32); + let new_width = (original_width as f32 * scale_factor).floor() as usize; + let padding = (current_width - new_width) / 2; + tensor.i((.., .., padding..current_width - padding)) + } +} + +pub struct IdentityMap {} + +impl Module for IdentityMap { + fn forward(&self, x: &Tensor) -> Result { + Ok(x.clone()) + } +} + +pub struct MMProjector { + pub modules: Sequential, +} + +impl MMProjector { + pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result { + if config.mm_projector_type == "linear" { + let vb_prefix = if config.hf { + "multi_modal_projector.linear_1" + } else { + "model.mm_projector.0" + }; + let linear = linear(config.mm_hidden_size, config.hidden_size, vb.pp(vb_prefix))?; + let modules = seq().add(linear); + Ok(Self { modules }) + } else if let Some(mlp_depth) = mlp_gelu_match(&config.mm_projector_type) { + let modules = if config.hf { + let mut modules = seq().add(linear( + config.mm_hidden_size, + config.hidden_size, + vb.pp("multi_modal_projector.linear_1"), + )?); + for i in 1..mlp_depth { + modules = modules.add(Activation::Gelu).add(linear( + config.hidden_size, + config.hidden_size, + vb.pp(format!("multi_modal_projector.linear_{}", i + 1)), + )?); + } + modules + } else { + let mut modules = seq().add(linear( + config.mm_hidden_size, + config.hidden_size, + vb.pp("model.mm_projector.0"), + )?); + for i in 1..mlp_depth { + modules = modules.add(Activation::Gelu).add(linear( + config.hidden_size, + config.hidden_size, + vb.pp(format!("model.mm_projector.{}", i * 2)), + )?); + } + modules + }; + Ok(Self { modules }) + } else if config.mm_projector_type == "identity" { + Ok(Self { + modules: seq().add(IdentityMap {}), + }) + } else { + bail!( + "Unsupported MM projector type: {}", + config.mm_projector_type + ) + } + } + + pub fn forward(&self, x: &Tensor) -> Result { + self.modules.forward(x) + } +} + +pub struct ClipVisionTower { + model: ClipVisionTransformer, + select_layer: isize, + select_feature_method: String, + pub config: ClipVisionConfig, +} + +impl ClipVisionTower { + pub fn new( + vb: VarBuilder, + select_layer: isize, + select_feature_method: &str, + config: &Option, + ) -> Result { + let config = if config.is_none() { + ClipVisionConfig::clip_vit_large_patch14_336() + } else { + config.clone().unwrap() + }; + let select_layer = match select_layer { + -1 | -2 => select_layer, + _ => bail!("Unsupported select layer: {}", select_layer), + }; + let model = ClipVisionTransformer::new(vb, &config)?; + Ok(Self { + model, + select_layer, + select_feature_method: select_feature_method.to_string(), + config, + }) + } + + pub fn forward(&self, x: &Tensor) -> Result { + let result = self.model.output_hidden_states(x)?; + let index = result.len() as isize + self.select_layer; + let result = result[index as usize].clone(); + if self.select_feature_method == "cls_patch" { + Ok(result) + } else { + result.i((.., 1..)) + } + } + + pub fn num_patches_per_side(&self) -> usize { + self.config.image_size / self.config.patch_size + } +} + +pub struct LLaVA { + pub clip_vision_tower: ClipVisionTower, + pub image_newline: Tensor, + pub mm_projector: MMProjector, + pub llama: Llama, + config: LLaVAConfig, + device: Device, +} + +impl LLaVA { + pub fn load( + vb: VarBuilder, + config: &LLaVAConfig, + clip_vision_config: 
Option, + ) -> Result { + let device = vb.device().clone(); + let llama_config = config.to_llama_config(); + let mm_projector = MMProjector::load(&vb, config)?; + let (clip_vision_tower, image_newline, llama) = if config.hf { + ( + ClipVisionTower::new( + vb.pp("vision_tower.vision_model"), + config.mm_vision_select_layer, + &config.mm_vision_select_feature, + &clip_vision_config, + )?, + vb.get(&[config.hidden_size], "image_newline")? + .to_device(&device)?, + Llama::load(vb.pp("language_model"), &llama_config)?, + ) + } else { + ( + ClipVisionTower::new( + vb.pp("model.vision_tower.vision_tower.vision_model"), + config.mm_vision_select_layer, + &config.mm_vision_select_feature, + &clip_vision_config, + )?, + vb.get(&[config.hidden_size], "model.image_newline")? + .to_device(&device)?, + Llama::load(vb, &llama_config)?, + ) + }; + Ok(Self { + clip_vision_tower, + image_newline, + mm_projector, + llama, + config: (*config).clone(), + device, + }) + } + + pub fn encode_images(&self, x: &Tensor) -> Result { + let image_features = self.clip_vision_tower.forward(x)?; + let image_features = self.mm_projector.forward(&image_features)?; + Ok(image_features) + } + // currently only for single image, 4 dim tensor + pub fn prepare_inputs_labels_for_multimodal( + &self, + input_ids: &Tensor, + images: &[Tensor], + image_sizes: &[(u32, u32)], + ) -> Result { + //TODO: process of multiple images/ new line + // 576: 336(input size)/14(patch size)=24 24*24+1(class)=577 577-1=576 + let concat_images = Tensor::cat(images, 0)?; + let image_features_together = self.encode_images(&concat_images)?; + let split_sizes = images + .iter() + .map(|x| x.shape().dims()[0]) + .collect::>(); + // can be replaced by split + let mut index_pos = 0; + let mut image_features = Vec::new(); + for split_size in split_sizes.iter() { + image_features.push(image_features_together.i(index_pos..index_pos + (*split_size))?); + index_pos += *split_size; + } + let mm_patch_merge_type = &self.config.mm_patch_merge_type; + let image_aspect_ratio = &self.config.image_aspect_ratio; + + let image_features = if mm_patch_merge_type == "flat" { + image_features + .iter() + .map(|x| x.flatten(0, 1).unwrap()) + .collect::>() + } else if mm_patch_merge_type.starts_with("spatial") { + let mut new_image_features = Vec::new(); + for (image_idx, image_feature) in image_features.iter().enumerate() { + let new_image_feature = if image_feature.dims()[0] > 1 { + let base_image_feature = image_feature.get(0).unwrap(); + let patch_image_feature = image_feature.i(1..).unwrap(); + let height = self.clip_vision_tower.num_patches_per_side(); + let width = height; + assert_eq!(height * width, base_image_feature.dims()[0]); + let image_size = image_sizes[image_idx]; + let new_image_feature = if image_aspect_ratio == "anyres" { + let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape( + image_size, + &self.config.image_grid_pinpoints, + self.clip_vision_tower.config.image_size as u32, + ); + patch_image_feature.reshape(( + num_patch_height as usize, + num_patch_width as usize, + height, + width, + (), + ))? + } else { + todo!("not implemented in original python LLaVA yet") + }; + let new_image_feature = if mm_patch_merge_type.contains("unpad") { + let new_image_feature = new_image_feature + .permute((4, 0, 2, 1, 3))? + .flatten(1, 2)? 
+ .flatten(2, 3)?; + let new_image_feature = unpad_image(&new_image_feature, &image_size)?; + let new_image_feature_dims = new_image_feature.dims(); + let image_new_line = self + .image_newline + .reshape((self.config.hidden_size, 1, 1))? + .broadcast_as(( + new_image_feature_dims[0], + new_image_feature_dims[1], + 1, + ))?; + let new_image_feature = + Tensor::cat(&[new_image_feature, image_new_line], 2)?; + new_image_feature.flatten(1, 2)?.transpose(0, 1)? + } else { + new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)? + }; + Tensor::cat(&[base_image_feature, new_image_feature], 0)? + } else { + let new_image_feature = image_feature.get(0).unwrap(); + if mm_patch_merge_type.contains("unpad") { + Tensor::cat( + &[ + new_image_feature, + self.image_newline.clone().unsqueeze(0).unwrap(), + ], + 0, + ) + .unwrap() + } else { + new_image_feature + } + }; + new_image_features.push(new_image_feature); + } + new_image_features + } else { + bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}") + }; + // can easily be replaced by nonzero if it is implemented in candle + let input_ids_vec = input_ids.squeeze(0)?.to_vec1::()?; + let mut image_indices = { + let mut image_indices = vec![0_i64]; + image_indices.extend( + input_ids_vec + .iter() + .enumerate() + .filter_map(|(i, x)| { + if *x == self.config.image_token_index as i64 { + Some(i as i64) + } else { + None + } + }) + .collect::>(), + ); + image_indices + }; + if image_indices.len() == 1 { + //no image, only [0], + return self.llama.embed(input_ids); + } + + let input_ids_noim = input_ids_vec + .iter() + .filter_map(|x| { + if *x != self.config.image_token_index as i64 { + Some(*x) + } else { + None + } + }) + .collect::>(); + let input_ids_noim_len = input_ids_noim.len(); + image_indices.push((input_ids_noim_len) as i64); + let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?; + let cur_input_embeds = self.llama.embed(&input_ids_noim)?; + // can be replace by split if it is implemented in candle + let input_embed_no_ims = { + let mut input_embeds = Vec::new(); + for i in 0..image_indices.len() - 1 { + let start = (image_indices[i]) as usize; + let end = image_indices[i + 1] as usize; + input_embeds.push(cur_input_embeds.i((start..end, ..))?) + } + input_embeds + }; + + let mut cur_new_input_embeds = Vec::new(); + for (i, image_feature) in image_features.iter().enumerate() { + cur_new_input_embeds.push(input_embed_no_ims[i].clone()); + cur_new_input_embeds.push(image_feature.clone()); + } + cur_new_input_embeds.push(input_embed_no_ims[image_features.len()].clone()); + let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?; + //trancate + let new_input_embeds = + if let Some(tokenizer_model_max_length) = self.config.tokenizer_model_max_length { + let (new_input_embeds_length, _) = new_input_embeds.shape().dims2()?; + if new_input_embeds_length > tokenizer_model_max_length { + new_input_embeds.i((..tokenizer_model_max_length, ..))? 
+ } else { + new_input_embeds + } + } else { + new_input_embeds + }; + new_input_embeds.unsqueeze(0) + } + + pub fn forward( + &self, + input_embeds: &Tensor, + position_id: usize, + cache: &mut Cache, + ) -> Result { + self.llama + .forward_input_embed(input_embeds, position_id, cache) + } +} diff --git a/candle-transformers/src/models/llava/utils.rs b/candle-transformers/src/models/llava/utils.rs new file mode 100644 index 00000000..3b4c18bb --- /dev/null +++ b/candle-transformers/src/models/llava/utils.rs @@ -0,0 +1,41 @@ +pub fn get_anyres_image_grid_shape( + image_size: (u32, u32), + grid_pinpoints: &[(u32, u32)], + patch_size: u32, +) -> (u32, u32) { + let (width, height) = select_best_resolution(image_size, grid_pinpoints); + (width / patch_size, height / patch_size) +} + +pub fn select_best_resolution( + original_size: (u32, u32), + possible_resolutions: &[(u32, u32)], +) -> (u32, u32) { + let (original_width, original_height) = original_size; + let mut best_fit = (0, 0); + let original_width_f = original_width as f32; + let original_height_f = original_height as f32; + let mut max_effective_resolution = 0_u32; + let mut min_wasted_resolution = u32::MAX; + for (width, height) in possible_resolutions { + let width_f = *width as f32; + let height_f = *height as f32; + let scale = (width_f / original_width_f).min(height_f / original_height_f); + let (downscaled_width, downscaled_height) = ( + (original_width_f * scale) as u32, + (original_height_f * scale) as u32, + ); + let effective_resolution = + std::cmp::min((*width) * (*height), downscaled_width * downscaled_height); + let wasted_resolution = (*width) * (*height) - effective_resolution; + if effective_resolution > max_effective_resolution + || (effective_resolution == max_effective_resolution + && wasted_resolution < min_wasted_resolution) + { + best_fit = (*width, *height); + max_effective_resolution = effective_resolution; + min_wasted_resolution = wasted_resolution; + } + } + best_fit +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index de2430a2..4628a3de 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -17,6 +17,7 @@ pub mod jina_bert; pub mod llama; pub mod llama2_c; pub mod llama2_c_weights; +pub mod llava; pub mod mamba; pub mod marian; pub mod metavoice;