Pixtral polishing. (#2522)

* Pixtral polishing.

* Clippy fix.
This commit is contained in:
Laurent Mazare
2024-09-30 21:23:54 +02:00
committed by GitHub
parent 683ab698de
commit dfe9a00683
2 changed files with 29 additions and 12 deletions

View File

@ -48,6 +48,7 @@ pub struct Model {
pub vision_tower: vision_model::Model,
pub patch_size: usize,
pub dtype: candle::DType,
pub pos: usize,
}
impl Model {
@ -67,6 +68,31 @@ impl Model {
vision_tower,
patch_size: cfg.vision_config.patch_size,
dtype: vb.dtype(),
pos: 0,
})
}
pub fn clear_kv_cache(&mut self) {
self.language_model.clear_kv_cache();
self.pos = 0;
}
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
let image_embeds = self.vision_tower.forward(image)?;
self.multi_modal_projector.forward(&image_embeds)
}
pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let (_, seq_len) = input_ids.dims2()?;
let logits = self.language_model.forward(input_ids, self.pos)?;
self.pos += seq_len;
Ok(logits)
}
pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
let (_, seq_len, _) = xs.dims3()?;
let logits = self.language_model.forward_embeds(xs, None, self.pos)?;
self.pos += seq_len;
Ok(logits)
}
}