Add ColPali (#2524)

* add colpali

* cleanup

* fix clippy
Author: Akshay Ballal
Date: 2024-10-01 11:48:39 +02:00
Committed by: GitHub
Parent: 6110ad8d4f
Commit: 888d886dd8
7 changed files with 394 additions and 1 deletion


@@ -0,0 +1,42 @@
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
use super::paligemma;
use candle_nn::{linear, Linear};

pub struct Model {
    pub model: paligemma::Model,
    pub custom_text_projection: Linear,
}

impl Model {
    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
        let model = paligemma::Model::new(config, vb.pp("model"))?;
        let custom_text_projection = linear(
            config.text_config.hidden_size,
            128,
            vb.pp("custom_text_proj"),
        )?;
        Ok(Self {
            model,
            custom_text_projection,
        })
    }

    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
        let outputs = self
            .model
            .setup_without_projection(pixel_values, input_ids)?;
        let outputs = self.custom_text_projection.forward(&outputs)?;
        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
        Ok(outputs)
    }

    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let outputs = self.model.forward_without_projection(input_ids)?;
        let outputs = self.custom_text_projection.forward(&outputs)?;
        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
        Ok(outputs)
    }
}
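
Since forward_images and forward_text both return L2-normalized multi-vector embeddings (one 128-dimensional vector per token), retrieval with this model is ColPali-style late interaction: a query is scored against a page by summing, over query tokens, the best match among the page tokens. A minimal scoring sketch, not part of this commit; the function name, the dummy sizes, and the assumption that the batch dimension has already been squeezed out are mine:

use candle::{Device, Result, Tensor};

/// Late-interaction (MaxSim) score between one query and one page: for each
/// query-token vector, take its best-matching page-token vector and sum those
/// maxima over the query tokens.
///
/// `query`: (num_query_tokens, 128) from forward_text, batch dim squeezed out.
/// `page`:  (num_page_tokens, 128) from forward_images, batch dim squeezed out.
fn maxsim_score(query: &Tensor, page: &Tensor) -> Result<f32> {
    // (num_query_tokens, num_page_tokens) similarities; the embeddings are
    // already L2-normalized by forward_text / forward_images above.
    let sims = query.matmul(&page.t()?)?;
    // Max over page tokens, then sum over query tokens.
    sims.max(1)?.sum_all()?.to_scalar::<f32>()
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Dummy stand-ins for real query/page embeddings.
    let query = Tensor::randn(0f32, 1., (16, 128), &device)?;
    let page = Tensor::randn(0f32, 1., (1024, 128), &device)?;
    println!("maxsim score: {}", maxsim_score(&query, &page)?);
    Ok(())
}

In practice one such score is computed for every (query, page) pair and pages are ranked by it.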


@@ -403,7 +403,6 @@ impl Model {
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }

    pub fn forward_embeds(
        &mut self,
        xs: &Tensor,
@@ -420,6 +419,21 @@ impl Model {
            .apply(&self.lm_head)
    }

    // Forward the model and return the hidden states without the lm_head
    pub fn forward_embeds_without_projection(
        &mut self,
        xs: &Tensor,
        attn_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (_, _, _) = xs.dims3()?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
        }
        Ok(xs)
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
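
For orientation: the new helper mirrors forward_embeds up to the point where the final projection would apply. The dims3 call only enforces a rank-3 input, the embeddings are scaled by sqrt(hidden_size), and because the final norm and lm_head are skipped the output keeps the (batch, seq_len, hidden_size) shape instead of ending in vocabulary logits. A standalone shape-contract sketch, with illustrative sizes that are my assumptions rather than values from this diff:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in for the concatenated image/text embeddings the caller passes in
    // (illustrative sizes: 1024 image tokens + 16 text tokens, hidden size 2048).
    let (batch, seq_len, hidden_size) = (1usize, 1040usize, 2048usize);
    let xs = Tensor::zeros((batch, seq_len, hidden_size), DType::F32, &device)?;
    // Same scaling step as in forward_embeds_without_projection; the decoder
    // layers preserve this shape, so with norm/lm_head skipped the hidden
    // states come back as (batch, seq_len, hidden_size).
    let xs = (xs * (hidden_size as f64).sqrt())?;
    assert_eq!(xs.dims(), &[batch, seq_len, hidden_size]);
    Ok(())
}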


@@ -7,6 +7,7 @@ pub mod blip_text;
pub mod chatglm;
pub mod clip;
pub mod codegeex4_9b;
pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod dac;


@@ -33,6 +33,29 @@ impl Config {
            projection_dim: 2048,
        }
    }

    pub fn paligemma_3b_448() -> Self {
        Self {
            vision_config: siglip::VisionConfig::paligemma_3b_448(),
            text_config: gemma::Config {
                hidden_size: 2048,
                intermediate_size: 16384,
                num_attention_heads: 8,
                num_hidden_layers: 18,
                num_key_value_heads: 1,
                // Default values.
                rope_theta: 10000.,
                head_dim: 256,
                hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
                hidden_activation: None,
                attention_bias: false,
                max_position_embeddings: 8192,
                rms_norm_eps: 1e-6,
                vocab_size: 257216,
            },
            projection_dim: 2048,
        }
    }
}

#[derive(Clone, Debug)]
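
The new 448-resolution preset matters for the downstream sequence length: assuming the SigLIP vision tower keeps a 14x14 patch size at 448x448 (an assumption about siglip::VisionConfig::paligemma_3b_448, which is not shown in this diff), each page yields (448 / 14)^2 = 1024 patch embeddings, and setup_without_projection below concatenates these in front of the text-token embeddings. A quick arithmetic check:

fn main() {
    // Assumed vision-tower parameters for the 448 preset; the 14x14 patch size
    // is an assumption, not part of this diff.
    let image_size = 448usize;
    let patch_size = 14usize;
    let num_image_tokens = (image_size / patch_size).pow(2);
    // 32 patches per side -> 1024 image-token embeddings per page.
    assert_eq!(num_image_tokens, 1024);
    println!("image tokens per page: {num_image_tokens}");
}
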
@@ -102,6 +125,28 @@ impl Model {
        self.language_model.forward(input_ids, pos)
    }

    pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        self.clear_kv_cache();
        let input_embeds = self.language_model.embed_tokens().forward(input_ids)?;
        self.language_model
            .forward_embeds_without_projection(&input_embeds, None, 0)
    }

    pub fn setup_without_projection(
        &mut self,
        pixel_values: &Tensor,
        input_ids: &Tensor,
    ) -> Result<Tensor> {
        self.clear_kv_cache();
        let image_features = self
            .vision_tower
            .forward(pixel_values)?
            .apply(&self.multi_modal_projector)?;
        let image_features = crate::models::clip::div_l2_norm(&image_features)?;
        let text_features = self.language_model.embed_tokens().forward(input_ids)?;
        let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
        self.language_model
            .forward_embeds_without_projection(&input_embeds, None, 0)
    }

    pub fn clear_kv_cache(&mut self) {
        self.pos = 0;
        self.language_model.clear_kv_cache()
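
Tying the pieces together, a minimal usage sketch that is not part of this commit: the weight file name, the dummy pixel values and token ids, and the use of F32 on CPU are placeholders; real inputs would come from the PaliGemma tokenizer and image preprocessing.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::{colpali, paligemma};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Placeholder checkpoint path; the file must contain "model.*" and
    // "custom_text_proj.*" tensors, matching the prefixes used in colpali::Model::new.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["colpali.safetensors"], DType::F32, &device)?
    };
    let config = paligemma::Config::paligemma_3b_448();
    let mut model = colpali::Model::new(&config, vb)?;

    // Dummy stand-ins: a preprocessed 448x448 page, its prompt tokens, and a query.
    let pixel_values = Tensor::zeros((1, 3, 448, 448), DType::F32, &device)?;
    let image_prompt_ids = Tensor::zeros((1, 16), DType::U32, &device)?;
    let query_ids = Tensor::zeros((1, 16), DType::U32, &device)?;

    // Both calls return (batch, seq, 128) multi-vector embeddings,
    // L2-normalized along the last dimension.
    let page_embeddings = model.forward_images(&pixel_values, &image_prompt_ids)?;
    let query_embeddings = model.forward_text(&query_ids)?;
    println!("pages: {:?}, queries: {:?}", page_embeddings.dims(), query_embeddings.dims());
    Ok(())
}

From there, pages can be ranked per query with the MaxSim scoring sketched earlier.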