mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
42
candle-transformers/src/models/colpali.rs
Normal file
42
candle-transformers/src/models/colpali.rs
Normal file
@ -0,0 +1,42 @@
|
||||
use candle::{Module, Result, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
use super::paligemma;
|
||||
use candle_nn::{linear, Linear};
|
||||
|
||||
pub struct Model {
|
||||
pub model: paligemma::Model,
|
||||
pub custom_text_projection: Linear,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
|
||||
let model = paligemma::Model::new(config, vb.pp("model"))?;
|
||||
let custom_text_projection = linear(
|
||||
config.text_config.hidden_size,
|
||||
128,
|
||||
vb.pp("custom_text_proj"),
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
model,
|
||||
custom_text_projection,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let outputs = self
|
||||
.model
|
||||
.setup_without_projection(pixel_values, input_ids)?;
|
||||
let outputs = self.custom_text_projection.forward(&outputs)?;
|
||||
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
|
||||
Ok(outputs)
|
||||
}
|
||||
|
||||
pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let outputs = self.model.forward_without_projection(input_ids)?;
|
||||
let outputs = self.custom_text_projection.forward(&outputs)?;
|
||||
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
|
||||
Ok(outputs)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user