mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add PaliGemma. (#2519)
* Add PaliGemma. * PaliGemma inference loop. * Running PaliGemma example. * Tweak the prompt.
This commit is contained in:
@@ -362,6 +362,10 @@ impl Model {
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a shared reference to the model's token-embedding layer.
///
/// NOTE(review): presumably exposed so a caller can embed token ids
/// itself and feed the result to `forward_embeds` (e.g. to mix in
/// image embeddings) — confirm against the PaliGemma wrapper.
pub fn embed_tokens(&self) -> &candle_nn::Embedding {
    &self.embed_tokens
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
@@ -400,6 +404,22 @@ impl Model {
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn forward_embeds(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let (_, seq_len, _) = xs.dims3()?;
|
||||
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attn_mask, seqlen_offset)?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
|
Reference in New Issue
Block a user