Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Add PaliGemma. (#2519)
* Add PaliGemma.
* PaliGemma inference loop.
* Running PaliGemma example.
* Tweak the prompt.
@@ -362,6 +362,10 @@ impl Model {
         })
     }
 
+    pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+        &self.embed_tokens
+    }
+
     fn prepare_decoder_attention_mask(
         &self,
         b_size: usize,
@@ -400,6 +404,22 @@ impl Model {
             .apply(&self.lm_head)
     }
 
+    pub fn forward_embeds(
+        &mut self,
+        xs: &Tensor,
+        attn_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let (_, seq_len, _) = xs.dims3()?;
+        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+        }
+        xs.narrow(1, seq_len - 1, 1)?
+            .apply(&self.norm)?
+            .apply(&self.lm_head)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         for layer in self.layers.iter_mut() {
             layer.clear_kv_cache()
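These two hunks expose the pieces the new PaliGemma model needs from the Gemma text model: embed_tokens gives callers access to the token-embedding table, and forward_embeds runs the decoder on embeddings supplied by the caller instead of token ids. Below is a minimal sketch of how a multimodal wrapper might drive this entry point during prefill; the prefill helper and its argument names are illustrative, not code from this PR, and image_embeds is assumed to already have shape (batch, num_image_tokens, hidden_size).

use candle::{Module, Result, Tensor};
use candle_transformers::models::gemma::Model;

// Illustrative prefill helper (not part of this PR): image-derived embeddings are
// concatenated with the prompt's token embeddings and fed straight to the decoder,
// bypassing the usual token-id lookup done by `forward`.
fn prefill(model: &mut Model, image_embeds: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
    let text_embeds = model.embed_tokens().forward(input_ids)?;
    let input_embeds = Tensor::cat(&[image_embeds.clone(), text_embeds], 1)?;
    // Attention mask left as None (full attention over the image+prompt prefix),
    // zero offset since this first pass fills the KV cache.
    model.forward_embeds(&input_embeds, None, 0)
}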
@@ -46,6 +46,7 @@ pub mod moondream;
 pub mod mpt;
 pub mod olmo;
 pub mod openclip;
+pub mod paligemma;
 pub mod parler_tts;
 pub mod persimmon;
 pub mod phi;
candle-transformers/src/models/paligemma.rs (new file, 109 lines)
@@ -0,0 +1,109 @@
use crate::models::{gemma, siglip};
use candle::{Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};

#[derive(serde::Deserialize, Clone, Debug)]
pub struct Config {
    pub vision_config: siglip::VisionConfig,
    pub text_config: gemma::Config,
    pub projection_dim: usize,
}

impl Config {
    pub fn paligemma_3b_224() -> Self {
        // https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
        Self {
            vision_config: siglip::VisionConfig::paligemma_3b_224(),
            text_config: gemma::Config {
                hidden_size: 2048,
                intermediate_size: 16384,
                num_attention_heads: 8,
                num_hidden_layers: 18,
                num_key_value_heads: 1,
                vocab_size: 257216,
                // Default values.
                rope_theta: 10000.,
                head_dim: 256,
                hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
                hidden_activation: None,
                attention_bias: false,
                max_position_embeddings: 8192,
                rms_norm_eps: 1e-6,
            },
            projection_dim: 2048,
        }
    }
}

#[derive(Clone, Debug)]
pub struct MultiModalProjector {
    linear: Linear,
}

impl MultiModalProjector {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let linear = linear(
            cfg.vision_config.hidden_size,
            cfg.projection_dim,
            vb.pp("linear"),
        )?;
        Ok(Self { linear })
    }
}

impl Module for MultiModalProjector {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.linear)
    }
}

#[derive(Clone, Debug)]
pub struct Model {
    pos: usize,
    vision_tower: siglip::VisionModel,
    multi_modal_projector: MultiModalProjector,
    language_model: gemma::Model,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vision_tower = siglip::VisionModel::new(
            &cfg.vision_config,
            false,
            vb.pp("vision_tower.vision_model"),
        )?;
        let multi_modal_projector = MultiModalProjector::new(cfg, vb.pp("multi_modal_projector"))?;
        let language_model = gemma::Model::new(false, &cfg.text_config, vb.pp("language_model"))?;
        Ok(Self {
            pos: 0,
            language_model,
            vision_tower,
            multi_modal_projector,
        })
    }

    pub fn setup(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
        self.clear_kv_cache();
        let image_features = self
            .vision_tower
            .forward(pixel_values)?
            .apply(&self.multi_modal_projector)?;
        let image_features = crate::models::clip::div_l2_norm(&image_features)?;
        let text_features = self.language_model.embed_tokens().forward(input_ids)?;
        let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
        self.pos = input_embeds.dim(1)?;
        self.language_model.forward_embeds(&input_embeds, None, 0)
    }

    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let pos = self.pos;
        let seq_len = input_ids.dim(1)?;
        self.pos = pos + seq_len;
        self.language_model.forward(input_ids, pos)
    }

    pub fn clear_kv_cache(&mut self) {
        self.pos = 0;
        self.language_model.clear_kv_cache()
    }
}
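For context, here is a rough sketch of how this model might be driven end to end, in the spirit of the "Running PaliGemma example" item in the commit message. It is only an illustration: the weight filename, sampler settings, generation length, and the pre-tokenized prompt_ids / preprocessed pixel_values (a (1, 3, 224, 224) image tensor) are assumptions, not code from this PR.

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::paligemma::{Config, Model};

// Illustrative generation loop (not part of this PR). `pixel_values` is the
// preprocessed image and `prompt_ids` the tokenized prompt, both prepared elsewhere.
fn generate(pixel_values: &Tensor, prompt_ids: &Tensor, device: &Device) -> candle::Result<Vec<u32>> {
    // Weight file name is a placeholder; the config could also be deserialized from
    // the repo's config.json since `Config` derives serde::Deserialize.
    let weights = vec![std::path::PathBuf::from("model.safetensors")];
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weights, DType::F32, device)? };
    let mut model = Model::new(&Config::paligemma_3b_224(), vb)?;
    let mut sampler = LogitsProcessor::new(299792458, None, None);

    // Prefill: `setup` encodes the image, projects it, prepends it to the prompt
    // embeddings, and returns the logits for the last prompt position.
    let mut logits = model.setup(pixel_values, prompt_ids)?;
    let mut tokens = Vec::new();
    for _ in 0..64 {
        let last = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
        let next_token = sampler.sample(&last)?;
        tokens.push(next_token);
        // Decode one token at a time; `forward` advances the internal position so
        // the KV cache built during `setup` is reused.
        let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
        logits = model.forward(&input)?;
    }
    Ok(tokens)
}

A real example would additionally stop on the EOS token and decode the sampled ids back to text with the tokenizer.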