mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
@ -36,6 +36,7 @@ serde_json = { workspace = true }
|
|||||||
symphonia = { version = "0.5.3", features = ["all"], optional = true }
|
symphonia = { version = "0.5.3", features = ["all"], optional = true }
|
||||||
tokenizers = { workspace = true, features = ["onig"] }
|
tokenizers = { workspace = true, features = ["onig"] }
|
||||||
cpal = { version = "0.15.2", optional = true }
|
cpal = { version = "0.15.2", optional = true }
|
||||||
|
pdf2image = { version = "0.1.2" , optional = true}
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
@ -117,3 +118,7 @@ required-features = ["depth_anything_v2"]
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "silero-vad"
|
name = "silero-vad"
|
||||||
required-features = ["onnx"]
|
required-features = ["onnx"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "colpali"
|
||||||
|
required-features = ["pdf2image"]
|
18
candle-examples/examples/colpali/README.md
Normal file
18
candle-examples/examples/colpali/README.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Colpali
|
||||||
|
|
||||||
|
[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged)
|
||||||
|
|
||||||
|
```
|
||||||
|
wget https://arxiv.org/pdf/1706.03762.pdf
|
||||||
|
cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf"
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
Prompt: what is position encoding?
|
||||||
|
top 3 page numbers that contain similarity to the prompt
|
||||||
|
-----------------------------------
|
||||||
|
Page: 6
|
||||||
|
Page: 11
|
||||||
|
Page: 15
|
||||||
|
-----------------------------------
|
||||||
|
```
|
268
candle-examples/examples/colpali/main.rs
Normal file
268
candle-examples/examples/colpali/main.rs
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::colpali::Model;
|
||||||
|
use candle_transformers::models::{colpali, paligemma};
|
||||||
|
use clap::Parser;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use image::DynamicImage;
|
||||||
|
use pdf2image::{RenderOptionsBuilder, PDF};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
struct PageRetriever {
|
||||||
|
model: Model,
|
||||||
|
config: paligemma::Config,
|
||||||
|
pdf: PDF,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
range: pdf2image::Pages,
|
||||||
|
batch_size: usize,
|
||||||
|
top_k: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PageRetriever {
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
config: paligemma::Config,
|
||||||
|
pdf: PDF,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
device: &Device,
|
||||||
|
range: Option<pdf2image::Pages>,
|
||||||
|
batch_size: usize,
|
||||||
|
top_k: usize,
|
||||||
|
) -> Self {
|
||||||
|
let page_count = pdf.page_count();
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
config,
|
||||||
|
pdf,
|
||||||
|
device: device.clone(),
|
||||||
|
tokenizer,
|
||||||
|
range: range.unwrap_or_else(|| pdf2image::Pages::Range(1..=page_count)),
|
||||||
|
batch_size,
|
||||||
|
top_k,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> {
|
||||||
|
let pages = self
|
||||||
|
.pdf
|
||||||
|
.render(self.range.clone(), RenderOptionsBuilder::default().build()?)?;
|
||||||
|
Ok(pages)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> {
|
||||||
|
let tokens = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?;
|
||||||
|
let token_ids = tokens
|
||||||
|
.iter()
|
||||||
|
.map(|tokens| {
|
||||||
|
let tokens = tokens.get_ids().to_vec();
|
||||||
|
Tensor::new(tokens.as_slice(), &self.device)
|
||||||
|
})
|
||||||
|
.collect::<candle::Result<Vec<_>>>()?;
|
||||||
|
let input = Tensor::stack(&token_ids, 0)?;
|
||||||
|
Ok(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn images_to_tensor(
|
||||||
|
&self,
|
||||||
|
pages: &[DynamicImage],
|
||||||
|
image_size: usize,
|
||||||
|
) -> anyhow::Result<Tensor> {
|
||||||
|
let mut images = vec![];
|
||||||
|
for page in pages.iter() {
|
||||||
|
let img = page.resize_to_fill(
|
||||||
|
image_size as u32,
|
||||||
|
image_size as u32,
|
||||||
|
image::imageops::FilterType::Triangle,
|
||||||
|
);
|
||||||
|
let img = img.to_rgb8();
|
||||||
|
let img = img.into_raw();
|
||||||
|
let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)?
|
||||||
|
.permute((2, 0, 1))?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.affine(2. / 255., -1.)?;
|
||||||
|
images.push(img);
|
||||||
|
}
|
||||||
|
let images = Tensor::stack(&images, 0)?;
|
||||||
|
Ok(images)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> {
|
||||||
|
let dtype = if self.device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
|
||||||
|
let dummy_prompt: &str = "Describe the image";
|
||||||
|
|
||||||
|
let input = self.tokenize_batch(vec![prompt])?;
|
||||||
|
let dummy_input = self.tokenize_batch(vec![dummy_prompt])?;
|
||||||
|
|
||||||
|
let pages = self.get_images_from_pdf()?;
|
||||||
|
let mut all_scores = Vec::new();
|
||||||
|
for batch in pages.chunks(self.batch_size) {
|
||||||
|
let page_images = self
|
||||||
|
.images_to_tensor(batch, self.config.vision_config.image_size)?
|
||||||
|
.to_device(&self.device)?
|
||||||
|
.to_dtype(dtype)?;
|
||||||
|
let dummy_input = dummy_input.repeat((page_images.dims()[0], 0))?;
|
||||||
|
|
||||||
|
let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?;
|
||||||
|
let text_embeddings = self.model.forward_text(&input)?;
|
||||||
|
|
||||||
|
let scores = text_embeddings
|
||||||
|
.unsqueeze(1)?
|
||||||
|
.broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)?
|
||||||
|
.max(3)?
|
||||||
|
.sum(2)?;
|
||||||
|
let batch_scores: Vec<f32> = scores
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.to_vec2()?
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.collect();
|
||||||
|
all_scores.extend(batch_scores);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut indices: Vec<usize> = (0..all_scores.len()).collect();
|
||||||
|
indices.sort_by(|a, b| all_scores[*b].partial_cmp(&all_scores[*a]).unwrap());
|
||||||
|
|
||||||
|
let top_k_indices = indices[0..self.top_k].to_vec();
|
||||||
|
|
||||||
|
Ok(top_k_indices)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// number of top pages to show.
|
||||||
|
#[arg(long, default_value_t = 3)]
|
||||||
|
top_k: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weight_files: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
pdf: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
start: Option<u32>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
end: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match &args.model_id {
|
||||||
|
Some(model_id) => model_id.to_string(),
|
||||||
|
None => "vidore/colpali-v1.2-merged".to_string(),
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
|
||||||
|
let tokenizer_filename = match args.tokenizer_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => api
|
||||||
|
.repo(Repo::with_revision(
|
||||||
|
"vidore/colpali".to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
"main".to_string(),
|
||||||
|
))
|
||||||
|
.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let filenames = match args.weight_files {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
|
let config: paligemma::Config = paligemma::Config::paligemma_3b_448();
|
||||||
|
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
let device = candle_examples::device(false)?;
|
||||||
|
let dtype = if device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
let model = colpali::Model::new(&config, vb)?;
|
||||||
|
|
||||||
|
let pdf = PDF::from_file(args.pdf)?;
|
||||||
|
|
||||||
|
// check if start and end given in arg
|
||||||
|
let range = if let (Some(start), Some(end)) = (args.start, args.end) {
|
||||||
|
pdf2image::Pages::Range(start..=end)
|
||||||
|
} else {
|
||||||
|
pdf2image::Pages::Range(1..=pdf.page_count()) // can use pdf2image::Pages::All but there is a bug in the library which causes the first page to rendered twice.
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut retriever =
|
||||||
|
PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3);
|
||||||
|
let top_k_indices = retriever.retrieve(&args.prompt)?;
|
||||||
|
|
||||||
|
println!("Prompt: {}", args.prompt);
|
||||||
|
println!(
|
||||||
|
"top {} page numbers that contain similarity to the prompt",
|
||||||
|
retriever.top_k
|
||||||
|
);
|
||||||
|
println!("-----------------------------------");
|
||||||
|
for index in top_k_indices {
|
||||||
|
println!("Page: {:?}", index + 1);
|
||||||
|
}
|
||||||
|
println!("-----------------------------------");
|
||||||
|
Ok(())
|
||||||
|
}
|
42
candle-transformers/src/models/colpali.rs
Normal file
42
candle-transformers/src/models/colpali.rs
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
use candle::{Module, Result, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
|
||||||
|
use super::paligemma;
|
||||||
|
use candle_nn::{linear, Linear};
|
||||||
|
|
||||||
|
pub struct Model {
|
||||||
|
pub model: paligemma::Model,
|
||||||
|
pub custom_text_projection: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let model = paligemma::Model::new(config, vb.pp("model"))?;
|
||||||
|
let custom_text_projection = linear(
|
||||||
|
config.text_config.hidden_size,
|
||||||
|
128,
|
||||||
|
vb.pp("custom_text_proj"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
model,
|
||||||
|
custom_text_projection,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let outputs = self
|
||||||
|
.model
|
||||||
|
.setup_without_projection(pixel_values, input_ids)?;
|
||||||
|
let outputs = self.custom_text_projection.forward(&outputs)?;
|
||||||
|
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
|
||||||
|
Ok(outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let outputs = self.model.forward_without_projection(input_ids)?;
|
||||||
|
let outputs = self.custom_text_projection.forward(&outputs)?;
|
||||||
|
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
|
||||||
|
Ok(outputs)
|
||||||
|
}
|
||||||
|
}
|
@ -403,7 +403,6 @@ impl Model {
|
|||||||
.apply(&self.norm)?
|
.apply(&self.norm)?
|
||||||
.apply(&self.lm_head)
|
.apply(&self.lm_head)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_embeds(
|
pub fn forward_embeds(
|
||||||
&mut self,
|
&mut self,
|
||||||
xs: &Tensor,
|
xs: &Tensor,
|
||||||
@ -420,6 +419,21 @@ impl Model {
|
|||||||
.apply(&self.lm_head)
|
.apply(&self.lm_head)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forward the model and return the hidden states without the lm_head
|
||||||
|
pub fn forward_embeds_without_projection(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (_, _, _) = xs.dims3()?;
|
||||||
|
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, attn_mask, seqlen_offset)?
|
||||||
|
}
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
layer.clear_kv_cache()
|
layer.clear_kv_cache()
|
||||||
|
@ -7,6 +7,7 @@ pub mod blip_text;
|
|||||||
pub mod chatglm;
|
pub mod chatglm;
|
||||||
pub mod clip;
|
pub mod clip;
|
||||||
pub mod codegeex4_9b;
|
pub mod codegeex4_9b;
|
||||||
|
pub mod colpali;
|
||||||
pub mod convmixer;
|
pub mod convmixer;
|
||||||
pub mod convnext;
|
pub mod convnext;
|
||||||
pub mod dac;
|
pub mod dac;
|
||||||
|
@ -33,6 +33,29 @@ impl Config {
|
|||||||
projection_dim: 2048,
|
projection_dim: 2048,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn paligemma_3b_448() -> Self {
|
||||||
|
Self {
|
||||||
|
vision_config: siglip::VisionConfig::paligemma_3b_448(),
|
||||||
|
text_config: gemma::Config {
|
||||||
|
hidden_size: 2048,
|
||||||
|
intermediate_size: 16384,
|
||||||
|
num_attention_heads: 8,
|
||||||
|
num_hidden_layers: 18,
|
||||||
|
num_key_value_heads: 1,
|
||||||
|
// Default values.
|
||||||
|
rope_theta: 10000.,
|
||||||
|
head_dim: 256,
|
||||||
|
hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
|
||||||
|
hidden_activation: None,
|
||||||
|
attention_bias: false,
|
||||||
|
max_position_embeddings: 8192,
|
||||||
|
rms_norm_eps: 1e-6,
|
||||||
|
vocab_size: 257216,
|
||||||
|
},
|
||||||
|
projection_dim: 2048,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
@ -102,6 +125,28 @@ impl Model {
|
|||||||
self.language_model.forward(input_ids, pos)
|
self.language_model.forward(input_ids, pos)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
self.clear_kv_cache();
|
||||||
|
let input_embeds = self.language_model.embed_tokens().forward(input_ids)?;
|
||||||
|
self.language_model
|
||||||
|
.forward_embeds_without_projection(&input_embeds, None, 0)
|
||||||
|
}
|
||||||
|
pub fn setup_without_projection(
|
||||||
|
&mut self,
|
||||||
|
pixel_values: &Tensor,
|
||||||
|
input_ids: &Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
self.clear_kv_cache();
|
||||||
|
let image_features = self
|
||||||
|
.vision_tower
|
||||||
|
.forward(pixel_values)?
|
||||||
|
.apply(&self.multi_modal_projector)?;
|
||||||
|
let image_features = crate::models::clip::div_l2_norm(&image_features)?;
|
||||||
|
let text_features = self.language_model.embed_tokens().forward(input_ids)?;
|
||||||
|
let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
|
||||||
|
self.language_model
|
||||||
|
.forward_embeds_without_projection(&input_embeds, None, 0)
|
||||||
|
}
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
self.pos = 0;
|
self.pos = 0;
|
||||||
self.language_model.clear_kv_cache()
|
self.language_model.clear_kv_cache()
|
||||||
|
Reference in New Issue
Block a user