Add ColPali (#2524)

* add colpali

* cleanup

* fix clippy
This commit is contained in:
Akshay Ballal
2024-10-01 11:48:39 +02:00
committed by GitHub
parent 6110ad8d4f
commit 888d886dd8
7 changed files with 394 additions and 1 deletions

View File

@ -0,0 +1,42 @@
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
use super::paligemma;
use candle_nn::{linear, Linear};
/// ColPali model: a PaliGemma backbone paired with a linear projection head.
///
/// The projection head maps the backbone's text hidden size down to a
/// 128-dimensional embedding (see [`Model::new`]).
pub struct Model {
/// Underlying PaliGemma model whose un-projected outputs feed the projection head.
pub model: paligemma::Model,
/// Linear layer applied to the backbone output before L2 normalization.
pub custom_text_projection: Linear,
}
impl Model {
    /// Width of the embedding produced by the projection head.
    const PROJECTION_DIM: usize = 128;

    /// Builds the model from `config`, loading weights through `vb`.
    ///
    /// The backbone is loaded under the `model` prefix and the projection head
    /// under the `custom_text_proj` prefix.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while constructing the backbone or the
    /// linear projection layer.
    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
        let model = paligemma::Model::new(config, vb.pp("model"))?;
        let custom_text_projection = linear(
            config.text_config.hidden_size,
            Self::PROJECTION_DIM,
            vb.pp("custom_text_proj"),
        )?;
        Ok(Self {
            model,
            custom_text_projection,
        })
    }

    /// Embeds images together with their accompanying token ids.
    ///
    /// Runs the backbone's `setup_without_projection` on `pixel_values` and
    /// `input_ids`, then projects and L2-normalizes the result.
    ///
    /// # Errors
    ///
    /// Propagates any error from the backbone or from the tensor operations.
    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
        let hidden = self
            .model
            .setup_without_projection(pixel_values, input_ids)?;
        self.project_and_normalize(&hidden)
    }

    /// Embeds a text-only input.
    ///
    /// Runs the backbone's `forward_without_projection` on `input_ids`, then
    /// projects and L2-normalizes the result.
    ///
    /// # Errors
    ///
    /// Propagates any error from the backbone or from the tensor operations.
    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let hidden = self.model.forward_without_projection(input_ids)?;
        self.project_and_normalize(&hidden)
    }

    /// Applies the projection head and divides by the L2 norm along dimension 2.
    ///
    /// Shared by [`Model::forward_images`] and [`Model::forward_text`] so both
    /// paths normalize identically.
    ///
    /// NOTE(review): the reduction is hard-wired to dimension 2, so the input is
    /// assumed to have at least three dimensions (presumably batch, sequence,
    /// hidden) — confirm against the backbone. No epsilon is added to the norm,
    /// so an all-zero projected vector would produce non-finite values.
    fn project_and_normalize(&self, hidden: &Tensor) -> Result<Tensor> {
        let projected = self.custom_text_projection.forward(hidden)?;
        let norm = projected.sqr()?.sum_keepdim(2)?.sqrt()?;
        projected.broadcast_div(&norm)
    }
}