Files
candle/candle-transformers/src/models/colpali.rs
zachcp f689ce5d39 Documentation Pass for Models (#2617)
* links in chinese_clip

* links for clip model

* add mod docs for flux and llava

* module doc for MMDIT and MIMI

* add docs for a few more modesl

* mod docs for bert naser and beit

* add module docs for convmixer colpali codegeex and chatglm

* add another series of moddocs

* add  fastvit-llama2_c

* module docs mamba -> mobileone

* module docs from moondream-phi3

* mod docs for quantized and qwen

* update to yi

* fix long names

* Update llama2_c.rs

* Update llama2_c_weights.rs

* Fix the link for mimi + tweaks

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-11-15 08:30:15 +01:00

48 lines
1.5 KiB
Rust

//! Colpali Model for text/image similarity scoring.
//!
//! Colpali combines a vision encoder with an efficient LM for retrieving content.
//!
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
use super::paligemma;
use candle_nn::{linear, Linear};
pub struct Model {
pub model: paligemma::Model,
pub custom_text_projection: Linear,
}
impl Model {
pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
let model = paligemma::Model::new(config, vb.pp("model"))?;
let custom_text_projection = linear(
config.text_config.hidden_size,
128,
vb.pp("custom_text_proj"),
)?;
Ok(Self {
model,
custom_text_projection,
})
}
pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
let outputs = self
.model
.setup_without_projection(pixel_values, input_ids)?;
let outputs = self.custom_text_projection.forward(&outputs)?;
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
Ok(outputs)
}
pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let outputs = self.model.forward_without_projection(input_ids)?;
let outputs = self.custom_text_projection.forward(&outputs)?;
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
Ok(outputs)
}
}