Add LLaVA support (#2234)
* first commit
* llava
* clippy and fmt
* some fixes
* minor fixes
* remove useless file
* refactor: Remove llava/constants.rs and update llava/mod.rs
* modify variable name
* modify code after clippy
* Minor tweaks.

Co-authored-by: laurent <laurent.mazare@gmail.com>
4 candle-examples/examples/llava/constants.rs Normal file

@@ -0,0 +1,4 @@
pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";
114 candle-examples/examples/llava/conversation.rs Normal file

@@ -0,0 +1,114 @@
pub enum SeparatorStyle {
    Two,
    Mpt,
}
pub struct Conversation {
    pub system: String,
    pub roles: Vec<String>,
    pub messages: Vec<(String, Option<String>)>,
    pub offset: i32,
    pub sep_style: SeparatorStyle,
    pub sep: String,
    pub sep2: Option<String>,
    pub version: String,
}

impl Conversation {
    pub fn new(
        system: &str,
        roles: &[String],
        offset: i32,
        sep_style: SeparatorStyle,
        sep: &str,
        sep2: Option<&str>,
        version: &str,
    ) -> Self {
        Conversation {
            system: system.to_string(),
            roles: roles.to_vec(),
            messages: Vec::new(),
            offset,
            sep_style,
            sep: sep.to_string(),
            sep2: sep2.map(|s| s.to_string()),
            version: version.to_string(),
        }
    }

    pub fn conv_chatml_direct() -> Self {
        Conversation::new(
            "<|im_start|>system\nAnswer the questions.",
            &[
                "<|im_start|>user\n".to_string(),
                "<|im_start|>assistant\n".to_string(),
            ],
            0,
            SeparatorStyle::Mpt,
            "<|im_end|>",
            None,
            "mpt",
        )
    }

    pub fn conv_llava_v1() -> Self {
        Conversation::new(
            "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
            &[
                "USER".to_string(),
                "ASSISTANT".to_string(),
            ],
            0,
            SeparatorStyle::Two,
            " ",
            Some("</s>"),
            "v1"
        )
    }

    pub fn append_message(&mut self, role: String, message: Option<&str>) {
        self.messages.push((role, message.map(|s| s.to_string())))
    }

    pub fn append_user_message(&mut self, message: Option<&str>) {
        self.append_message(self.roles[0].clone(), message);
    }

    pub fn append_assistant_message(&mut self, message: Option<&str>) {
        self.append_message(self.roles[1].clone(), message);
    }

    pub fn get_prompt(&self) -> String {
        match self.sep_style {
            SeparatorStyle::Mpt => {
                let mut ret = String::new();
                ret.push_str(&self.system);
                ret.push_str(&self.sep);
                for (role, message) in &self.messages {
                    ret.push_str(role);
                    if let Some(message) = message {
                        ret.push_str(message);
                    };
                    ret.push_str(&self.sep);
                }
                ret
            }
            SeparatorStyle::Two => {
                let seps = [self.sep.clone(), self.sep2.clone().unwrap()];
                let mut ret = String::new();
                ret.push_str(&self.system);
                ret.push_str(&seps[0]);
                for (i, (role, message)) in self.messages.iter().enumerate() {
                    ret.push_str(role);
                    if let Some(message) = message {
                        ret.push_str(": "); // strictly follow the python implementation, otherwise it will cause some minor difference between tokens ^_^
                        ret.push_str(message);
                        ret.push_str(&seps[i % 2]);
                    } else {
                        ret.push(':')
                    }
                }
                ret
            }
        }
    }
}
317 candle-examples/examples/llava/image_processor.rs Normal file

@@ -0,0 +1,317 @@
use std::cmp::min;

use candle::{bail, DType, Device, Result, Tensor};
use candle_transformers::models::llava::{
    config::{HFPreProcessorConfig, LLaVAConfig},
    utils::select_best_resolution,
};
use hf_hub::api::sync::Api;
use image::{imageops::overlay, DynamicImage, GenericImageView, Rgb, RgbImage};
use serde::{Deserialize, Serialize};

// This struct is mainly for LLaVA applications, so it is not fully compatible with the python
// transformers CLIPImageProcessor. It covers the few preprocessing steps that LLaVA uses,
// including "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14".

#[derive(Serialize, Deserialize, Debug)]
pub struct ImageProcessor {
    #[serde(default = "default_size")]
    pub size: u32, // this is not the same as in python transformers
    #[serde(default = "default_do_resize")]
    pub do_resize: bool,

    // resample: u32 // 3 means PIL bicubic, equivalent to rust CatmullRom, hence CatmullRom is used below
    #[serde(default = "default_do_center_crop")]
    pub do_center_crop: bool,
    #[serde(default = "default_crop_size")]
    pub crop_size: u32, // this is not the same as in python transformers
    #[serde(default = "default_do_rescale")]
    pub do_rescale: bool,
    #[serde(default = "default_rescale_factor")]
    pub rescale_factor: f32,
    #[serde(default = "default_do_normalize")]
    pub do_normalize: bool,
    #[serde(default = "default_image_mean")]
    pub image_mean: Vec<f32>,
    #[serde(default = "default_image_std")]
    pub image_std: Vec<f32>,
}

fn default_size() -> u32 {
    224
}

fn default_do_resize() -> bool {
    true
}

fn default_do_center_crop() -> bool {
    true
}

fn default_crop_size() -> u32 {
    224
}

fn default_do_rescale() -> bool {
    true
}

fn default_rescale_factor() -> f32 {
    1.0 / 255.0
}

fn default_do_normalize() -> bool {
    true
}

fn default_image_mean() -> Vec<f32> {
    vec![0.48145466, 0.4578275, 0.40821073]
}

fn default_image_std() -> Vec<f32> {
    vec![0.26862954, 0.2613026, 0.2757771]
}

impl ImageProcessor {
    pub fn from_pretrained(clip_id: &str) -> Result<Self> {
        let api = Api::new().map_err(|e| candle::Error::Msg(e.to_string()))?;
        let api = api.model(clip_id.to_string());
        let config_filename = api
            .get("preprocessor_config.json")
            .map_err(|e| candle::Error::Msg(e.to_string()))?;
        let image_processor =
            serde_json::from_slice(&std::fs::read(config_filename).map_err(candle::Error::Io)?)
                .map_err(|e| candle::Error::Msg(e.to_string()))?;
        Ok(image_processor)
    }

    pub fn from_hf_preprocessor_config(hf_preprocessor_config: &HFPreProcessorConfig) -> Self {
        Self {
            size: hf_preprocessor_config.size["shortest_edge"] as u32,
            do_resize: hf_preprocessor_config.do_resize,
            do_center_crop: hf_preprocessor_config.do_center_crop,
            crop_size: hf_preprocessor_config.crop_size["height"] as u32,
            do_rescale: hf_preprocessor_config.do_rescale,
            rescale_factor: hf_preprocessor_config.rescale_factor,
            do_normalize: hf_preprocessor_config.do_normalize,
            image_mean: hf_preprocessor_config.image_mean.clone(),
            image_std: hf_preprocessor_config.image_std.clone(),
        }
    }

    /// Resize the shortest edge to `self.size`; the other edge is scaled to keep the aspect ratio.
    pub fn resize(&self, image: &DynamicImage) -> DynamicImage {
        let (width, height) = image.dimensions();
        let size = self.size;
        if width == size && height == size {
            image.clone()
        } else {
            let (new_width, new_height) = if width < height {
                (
                    size,
                    (((size * height) as f32) / width as f32).ceil() as u32,
                )
            } else {
                (
                    (((size * width) as f32) / height as f32).ceil() as u32,
                    size,
                )
            };
            image.resize(
                new_width,
                new_height,
                image::imageops::FilterType::CatmullRom,
            )
        }
    }

    pub fn center_crop(&self, image: &DynamicImage) -> DynamicImage {
        let (width, height) = image.dimensions();
        let crop_size = self.crop_size;
        let (left, top) = calculate_middle((width, height), (crop_size, crop_size));
        image.crop_imm(left, top, crop_size, crop_size)
    }

    pub fn to_tensor(&self, image: &DynamicImage) -> Result<Tensor> {
        let img = image.to_rgb8().into_raw();
        let (width, height) = image.dimensions();
        Tensor::from_vec(img, (height as usize, width as usize, 3), &Device::Cpu)?
            .to_dtype(DType::F32) // only for internal compute
    }

    pub fn rescale(&self, tensor: &Tensor) -> Result<Tensor> {
        let rescale_factor = self.rescale_factor as f64;
        tensor.affine(rescale_factor, 0.0)
    }

    pub fn normalize(&self, tensor: &Tensor) -> Result<Tensor> {
        let image_mean = self.image_mean.clone();
        let image_std = self.image_std.clone();
        let mean = Tensor::from_vec(image_mean, (3,), &Device::Cpu)?;
        let std = Tensor::from_vec(image_std, (3,), &Device::Cpu)?;
        tensor.broadcast_sub(&mean)?.broadcast_div(&std)
    }

    pub fn to_channel_dimension_format(&self, tensor: &Tensor) -> Result<Tensor> {
        tensor.permute((2, 0, 1))
    }

    pub fn preprocess(&self, image: &DynamicImage) -> Result<Tensor> {
        let image = if self.do_resize {
            self.resize(image)
        } else {
            image.clone()
        };
        let image = if self.do_center_crop {
            self.center_crop(&image)
        } else {
            image
        };
        let tensor = self.to_tensor(&image)?;
        let tensor = if self.do_rescale {
            self.rescale(&tensor)?
        } else {
            tensor
        };
        let tensor = if self.do_normalize {
            self.normalize(&tensor)?
        } else {
            tensor
        };
        self.to_channel_dimension_format(&tensor)
    }
}

pub fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
    let (width, height) = image_size;
    let (center_width, center_height) = center_size;
    let left = if width <= center_width {
        0
    } else {
        ((width as f32 - center_width as f32) / 2.0).ceil() as u32
    };
    let top = if height <= center_height {
        0
    } else {
        ((height as f32 - center_height as f32) / 2.0).ceil() as u32
    };
    (left, top)
}

pub fn process_image(
    image: &DynamicImage,
    processor: &ImageProcessor,
    llava_config: &LLaVAConfig,
) -> candle::Result<Tensor> {
    if llava_config.image_aspect_ratio == *"square" {
        processor.preprocess(image)?.unsqueeze(0)
    } else if llava_config.image_aspect_ratio == *"anyres" {
        process_anyres_image(image, processor, &llava_config.image_grid_pinpoints)
    } else if llava_config.image_aspect_ratio == *"pad" {
        process_pad_image(image, processor)
    } else {
        bail!("Invalid image aspect ratio")
    }
}

fn process_pad_image(image: &DynamicImage, processor: &ImageProcessor) -> Result<Tensor> {
    let mean_color = processor
        .image_mean
        .iter()
        .map(|x| ((*x) * 255.0) as u8)
        .collect::<Vec<u8>>();
    let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
    let image_padded = expand2square(image, mean_color);
    processor.preprocess(&image_padded)
}

fn process_anyres_image(
    image: &DynamicImage,
    processor: &ImageProcessor,
    grid_pinpoints: &[(u32, u32)],
) -> Result<Tensor> {
    let original_size = image.dimensions();
    let best_resolution = select_best_resolution(original_size, grid_pinpoints);
    let image_padded = resize_and_pad_image(image, best_resolution);
    let image_original_resize = image.resize_exact(
        processor.size,
        processor.size,
        image::imageops::FilterType::CatmullRom,
    );
    let mut patches = vec![image_original_resize];
    for patch in divide_to_patches(&image_padded, processor.crop_size) {
        patches.push(patch);
    }
    let tensors = patches
        .iter()
        .map(|patch| processor.preprocess(patch))
        .collect::<Result<Vec<Tensor>>>()?;
    Tensor::stack(&tensors, 0)
}

fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
    let (width, height) = image.dimensions();
    match width.cmp(&height) {
        std::cmp::Ordering::Less => {
            let mut new_image =
                DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
            overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
            new_image
        }
        std::cmp::Ordering::Equal => image.clone(),
        std::cmp::Ordering::Greater => {
            let mut new_image =
                DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
            overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
            new_image
        }
    }
}

fn resize_and_pad_image(image: &DynamicImage, target_resolution: (u32, u32)) -> DynamicImage {
    let (original_width, original_height) = image.dimensions();
    let original_width_f = original_width as f32;
    let original_height_f = original_height as f32;
    let (target_width, target_height) = target_resolution;
    let target_width_f = target_width as f32;
    let target_height_f = target_height as f32;
    let scale_w = target_width_f / original_width_f;
    let scale_h = target_height_f / original_height_f;
    let (new_width, new_height) = if scale_w < scale_h {
        (
            target_width,
            min((original_height_f * scale_w).ceil() as u32, target_height),
        )
    } else {
        (
            min((original_width_f * scale_h).ceil() as u32, target_width),
            target_height,
        )
    };
    let resized_image = image.resize_exact(
        new_width,
        new_height,
        image::imageops::FilterType::CatmullRom,
    );
    let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
    let (paste_x, paste_y) =
        calculate_middle((target_width, target_height), (new_width, new_height));
    overlay(
        &mut new_image,
        &resized_image,
        paste_x.into(),
        paste_y.into(),
    );
    new_image
}

fn divide_to_patches(image: &DynamicImage, patch_size: u32) -> Vec<DynamicImage> {
    let (width, height) = image.dimensions();
    let mut patches = Vec::new();
    for y in (0..height).step_by(patch_size as usize) {
        for x in (0..width).step_by(patch_size as usize) {
            let patch = image.crop_imm(x, y, patch_size, patch_size);
            patches.push(patch);
        }
    }
    patches
}
316 candle-examples/examples/llava/main.rs Normal file

@@ -0,0 +1,316 @@
pub mod constants;
pub mod conversation;
pub mod image_processor;

use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::Cache;

use anyhow::{bail, Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llava::config::{
    HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig,
};
use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA};
use clap::Parser;
use constants::*;
use conversation::Conversation;
use hf_hub::api::sync::Api;
use image_processor::{process_image, ImageProcessor};
use std::io::Write;
use tokenizers::Tokenizer;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    #[arg(long, default_value = "llava-hf/llava-v1.6-vicuna-7b-hf")]
    model_path: String,
    #[arg(long, default_value = "tokenizer/tokenizer.json")]
    tokenizer_path: String,
    #[arg(long)]
    model_base: Option<String>,
    #[arg(long)]
    image_file: String, // Required
    #[arg(long)]
    conv_mode: Option<String>,
    #[arg(long, default_value_t = 0.2)]
    temperature: f32,
    #[arg(long, default_value_t = 512)]
    max_new_tokens: usize,
    #[arg(long, action)]
    hf: bool,
    #[arg(long, action)]
    cpu: bool,
    #[arg(long, action)]
    no_kv_cache: bool,
    #[arg(long)]
    prompt: String,
    /// The seed to use when generating random samples. Copied from the candle llama example; it does not exist in the python llava.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,
}

// from https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs
fn load_image<T: AsRef<std::path::Path>>(
    path: T,
    processor: &ImageProcessor,
    llava_config: &LLaVAConfig,
    dtype: DType,
) -> Result<((u32, u32), Tensor)> {
    let img = image::io::Reader::open(path)?.decode()?;
    let img_tensor = process_image(&img, processor, llava_config)?;
    Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
}

fn get_model_name_from_path(model_path: &str) -> String {
    let model_paths: Vec<String> = model_path
        .trim_matches('/')
        .split('/')
        .map(|s| s.to_string())
        .collect();
    if model_paths.last().unwrap().starts_with("checkpoint-") {
        format!(
            "{}_{}",
            model_paths[model_paths.len() - 2],
            model_paths.last().unwrap()
        )
    } else {
        model_paths.last().unwrap().to_string()
    }
}

fn duplicate_vec<T>(vec: &[T], n: usize) -> Vec<T>
where
    T: Clone,
{
    let mut res = Vec::new();
    for _ in 0..n {
        res.extend(vec.to_owned());
    }
    res
}

fn insert_separator<T>(x: Vec<Vec<T>>, sep: Vec<T>) -> Vec<Vec<T>>
where
    T: Clone,
{
    let sep = vec![sep];
    let sep = duplicate_vec(&sep, x.len());
    let mut res = x
        .iter()
        .zip(sep.iter())
        .flat_map(|(x, y)| vec![x.clone(), y.clone()])
        .collect::<Vec<Vec<T>>>();
    res.pop();
    res
}

fn tokenizer_image_token(
    prompt: &str,
    tokenizer: &Tokenizer,
    image_token_index: i64,
    llava_config: &LLaVAConfig,
) -> Result<Tensor> {
    let prompt_chunks = prompt
        .split("<image>")
        .map(|s| {
            tokenizer
                .encode(s, true)
                .unwrap()
                .get_ids()
                .to_vec()
                .iter()
                .map(|x| *x as i64)
                .collect()
        })
        .collect::<Vec<Vec<i64>>>();
    let mut input_ids = Vec::new();
    let mut offset = 0;
    if !prompt_chunks.is_empty()
        && !prompt_chunks[0].is_empty()
        && prompt_chunks[0][0] == llava_config.bos_token_id as i64
    {
        offset = 1;
        input_ids.push(prompt_chunks[0][0]);
    }

    for x in insert_separator(
        prompt_chunks,
        duplicate_vec(&[image_token_index], offset + 1),
    )
    .iter()
    {
        input_ids.extend(x[1..].to_vec())
    }
    let input_len = input_ids.len();
    Tensor::from_vec(input_ids, (1, input_len), &Device::Cpu).map_err(E::msg)
}

fn main() -> Result<()> {
    let mut args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    println!("Start loading model");
    let api = Api::new()?;
    let api = api.model(args.model_path.clone());
    let (llava_config, tokenizer, clip_vision_config, image_processor) = if args.hf {
        let config_filename = api.get("config.json")?;
        let hf_llava_config: HFLLaVAConfig =
            serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let generation_config_filename = api.get("generation_config.json")?;
        let generation_config: HFGenerationConfig =
            serde_json::from_slice(&std::fs::read(generation_config_filename)?)?;
        let preprocessor_config_filename = api.get("preprocessor_config.json")?;
        let preprocessor_config: HFPreProcessorConfig =
            serde_json::from_slice(&std::fs::read(preprocessor_config_filename)?)?;
        let llava_config =
            hf_llava_config.to_llava_config(&generation_config, &preprocessor_config);
        let tokenizer_filename = api.get("tokenizer.json")?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        let clip_vision_config = hf_llava_config.to_clip_vision_config();
        (
            llava_config,
            tokenizer,
            Some(clip_vision_config),
            ImageProcessor::from_hf_preprocessor_config(&preprocessor_config),
        )
    } else {
        let config_filename = api.get("config.json")?;
        let llava_config: LLaVAConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let tokenizer = Tokenizer::from_file(&args.tokenizer_path)
            .map_err(|e| E::msg(format!("Error loading {}: {}", &args.tokenizer_path, e)))?;
        (
            llava_config.clone(),
            tokenizer,
            None,
            ImageProcessor::from_pretrained(&llava_config.mm_vision_tower.unwrap())?,
        )
    };

    let llama_config = llava_config.to_llama_config();
    let dtype: DType = match llava_config.torch_dtype.as_str() {
        "float16" => DType::F16,
        "bfloat16" => DType::BF16,
        _ => bail!("unsupported dtype"),
    };

    let eos_token_id = llava_config.eos_token_id;

    println!("setting kv cache");
    let mut cache = Cache::new(!args.no_kv_cache, dtype, &llama_config, &device)?;

    println!("loading model weights");

    let weight_filenames =
        candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_filenames, dtype, &device)? };
    let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?;

    println!("generating conv template");
    let image_token_se = format!(
        "{}{}{}",
        DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN
    );
    let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) {
        if llava_config.mm_use_im_start_end {
            args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se)
        } else {
            args.prompt.replace(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN)
        }
    } else if llava_config.mm_use_im_start_end {
        format!("{}\n{}", image_token_se, args.prompt)
    } else {
        format!("{}\n{}", DEFAULT_IMAGE_TOKEN, args.prompt)
    };

    let model_name = get_model_name_from_path(&args.model_path).to_lowercase();
    let conv_mode = if model_name.contains("llama-2") {
        "llava_llama_2"
    } else if model_name.contains("mistral") {
        "mistral_instruct"
    } else if model_name.contains("v1.6-34b") {
        "chatml_direct"
    } else if model_name.contains("v1") {
        "llava_v1"
    } else if model_name.contains("mpt") {
        "mpt"
    } else {
        "llava_v0"
    };
    if args.conv_mode.is_some() && args.conv_mode.as_deref() != Some(conv_mode) {
        println!(
            "Warning: the model is trained with {}, but you are using {}",
            conv_mode,
            args.conv_mode.as_deref().unwrap()
        );
    } else {
        args.conv_mode = Some(conv_mode.to_string());
    }

    let mut conv = match args.conv_mode {
        Some(conv_mode) => match conv_mode.as_str() {
            "chatml_direct" => Conversation::conv_chatml_direct(),
            "llava_v1" => Conversation::conv_llava_v1(),
            _ => todo!("not implemented yet"),
        },
        None => bail!("conv_mode is required"),
    };
    conv.append_user_message(Some(&qs));
    conv.append_assistant_message(None);
    let prompt = conv.get_prompt();
    println!("loading image");
    let (image_size, image_tensor) =
        load_image(&args.image_file, &image_processor, &llava_config, dtype)
            .map_err(|e| E::msg(format!("Error loading {}: {}", &args.image_file, e)))?;
    let image_tensor = image_tensor.to_device(&device)?;

    let mut logits_processor = {
        let temperature = f64::from(args.temperature);
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            Sampling::All { temperature }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    // get input tokens
    let tokens = tokenizer_image_token(
        &prompt,
        &tokenizer,
        llava_config.image_token_index as i64,
        &llava_config,
    )?;
    let mut input_embeds =
        llava.prepare_inputs_labels_for_multimodal(&tokens, &[image_tensor], &[image_size])?;
    // inference loop, based on https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs
    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
    let mut index_pos = 0;
    for index in 0..args.max_new_tokens {
        let (_, input_embeds_len, _) = input_embeds.dims3()?;
        let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
            (1, index_pos)
        } else {
            (input_embeds_len, 0)
        };
        let input = input_embeds.i((.., input_embeds_len.saturating_sub(context_size).., ..))?;
        let logits = llava.forward(&input, context_index, &mut cache)?; // [1, 32000]
        let logits = logits.squeeze(0)?;
        let (_, input_len, _) = input.dims3()?;
        index_pos += input_len;
        let next_token = logits_processor.sample(&logits)?;
        let next_token_tensor = Tensor::from_vec(vec![next_token], 1, &device)?;
        let next_embeds = llava.llama.embed(&next_token_tensor)?.unsqueeze(0)?;
        input_embeds = Tensor::cat(&[input_embeds, next_embeds], 1)?;
        if next_token == eos_token_id as u32 {
            break;
        }
        if let Some(t) = tokenizer.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    Ok(())
}
40 candle-examples/examples/llava/readme.md Normal file

@@ -0,0 +1,40 @@
# candle-llava

LLaVA (Large Language-and-Vision Assistant) is an end-to-end trained large
multimodal model. This example is from [candle-llava](https://github.com/chenwanqq/candle-llava).

The code is based on [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA), hence the llava-hf version of the config may behave slightly differently.

## model zoo
* [liuhaotian/LLaVA](https://huggingface.co/liuhaotian)
* [llava-hf](https://huggingface.co/llava-hf)

Right now this has been tested on `liuhaotian/llava-v1.6-vicuna-7b` and
`llava-hf/llava-v1.6-vicuna-7b-hf`. Memory usage might have room for optimization.

## Tokenizer Setup
The llava-hf models contain a `tokenizer.json` file, so they can be used directly with
the `--hf` command line flag.

For the original llava models, you can use the following commands to generate the `tokenizer.json` file.

```bash
conda create -n llava python=3.10
conda activate llava
pip install transformers protobuf
python -c "from transformers import AutoTokenizer;tokenizer=AutoTokenizer.from_pretrained('liuhaotian/llava-v1.6-vicuna-7b');tokenizer.save_pretrained('tokenizer')"
```
The `tokenizer.json` file should then be in `tokenizer/tokenizer.json` (which is the default path).
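If the tokenizer was saved somewhere other than the default location, its path can be passed explicitly. This is a minimal sketch assuming clap's default kebab-case naming of the `tokenizer_path` field defined in `main.rs`; the path below is a placeholder.

```bash
# hypothetical path; point the example at a tokenizer.json outside the default location
cargo run --example llava --features cuda -- \
  --model-path liuhaotian/llava-v1.6-vicuna-7b \
  --tokenizer-path /path/to/tokenizer.json \
  --image-file "llava_logo.png" --prompt "is this a cat?"
```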

## eval

```bash
cargo run --example llava --features cuda -- --image-file "llava_logo.png" --prompt "is this a cat?" --hf # default args, use llava-hf/llava-v1.6-vicuna-7b-hf. image-file is required^_^
cargo run --example llava --features cuda -- --model-path liuhaotian/llava-v1.6-vicuna-7b --image-file "llava_logo.png" --prompt "is this a cat?" # use liuhaotian/llava-v1.6-vicuna-7b, tokenizer setup should be done
```
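The `Args` struct in `main.rs` also exposes generation options such as `temperature`, `max_new_tokens`, `seed`, `cpu` and `no_kv_cache`. The sketch below assumes clap's default kebab-case renaming of these field names (check `--help` if in doubt); per `main.rs`, a temperature of 0 or below switches sampling to argmax.

```bash
# CPU run with greedy decoding and a smaller generation budget (flag spellings assumed from Args)
cargo run --example llava -- --cpu --hf \
  --image-file "llava_logo.png" --prompt "describe this image" \
  --temperature 0.0 --max-new-tokens 128
```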
|
||||
|
||||
## Major Limitations
|
||||
1. Currently only support llama-2/vicuna llm. Haven't supoort Mistral yet.
|
||||
2. There are some ops like split, nonzero and where are not supported by candle.
|
||||
3. Lack of quantization and LoRA support.
|