Add PaliGemma. (#2519)

* Add PaliGemma.

* PaliGemma inference loop.

* Running PaliGemma example.

* Tweak the prompt.
This commit is contained in:
Laurent Mazare
2024-09-29 19:56:56 +02:00
committed by GitHub
parent 0ebb38813b
commit 2f49e1b534
5 changed files with 434 additions and 0 deletions

View File

@ -362,6 +362,10 @@ impl Model {
})
}
pub fn embed_tokens(&self) -> &candle_nn::Embedding {
&self.embed_tokens
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
@ -400,6 +404,22 @@ impl Model {
.apply(&self.lm_head)
}
pub fn forward_embeds(
&mut self,
xs: &Tensor,
attn_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (_, seq_len, _) = xs.dims3()?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attn_mask, seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()