mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
@ -48,6 +48,7 @@ pub struct Model {
|
||||
pub vision_tower: vision_model::Model,
|
||||
pub patch_size: usize,
|
||||
pub dtype: candle::DType,
|
||||
pub pos: usize,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
@ -67,6 +68,31 @@ impl Model {
|
||||
vision_tower,
|
||||
patch_size: cfg.vision_config.patch_size,
|
||||
dtype: vb.dtype(),
|
||||
pos: 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.language_model.clear_kv_cache();
|
||||
self.pos = 0;
|
||||
}
|
||||
|
||||
/// Encodes `image` by running it through the vision tower and then the
/// multi-modal projector.
///
/// # Errors
///
/// Propagates any error returned by either forward pass.
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
    // Stage 1: vision tower output for the image.
    let vision_features = self.vision_tower.forward(image)?;
    // Stage 2: pass the vision output through the multi-modal projector.
    let projected = self.multi_modal_projector.forward(&vision_features)?;
    Ok(projected)
}
|
||||
|
||||
/// Runs the language model on `input_ids` starting at the current position
/// and returns its output.
///
/// `input_ids` must be two-dimensional; the size of its second dimension is
/// treated as the sequence length and is added to `self.pos` after a
/// successful forward pass.
///
/// # Errors
///
/// Fails if `input_ids` is not two-dimensional or if the language model's
/// forward pass fails. In both cases `self.pos` is left unchanged.
pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
    let start_pos = self.pos;
    let (_batch, num_tokens) = input_ids.dims2()?;
    let output = self.language_model.forward(input_ids, start_pos)?;
    // Only advance once the forward pass has succeeded.
    self.pos = start_pos + num_tokens;
    Ok(output)
}
|
||||
|
||||
/// Runs the language model on the precomputed embeddings `xs` starting at
/// the current position and returns its output.
///
/// `xs` must be three-dimensional; the size of its second dimension is
/// treated as the sequence length and is added to `self.pos` after a
/// successful forward pass.
///
/// # Errors
///
/// Fails if `xs` is not three-dimensional or if the language model's forward
/// pass fails. In both cases `self.pos` is left unchanged.
pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
    let start_pos = self.pos;
    let (_batch, num_tokens, _hidden) = xs.dims3()?;
    // NOTE(review): the `None` argument is presumably an optional attention
    // mask; confirm against `forward_embeds`.
    let output = self
        .language_model
        .forward_embeds(xs, None, start_pos)?;
    // Only advance once the forward pass has succeeded.
    self.pos = start_pos + num_tokens;
    Ok(output)
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user