From dfe9a006834938a7d4dde6a6e3b81ed6e595bf99 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 30 Sep 2024 21:23:54 +0200
Subject: [PATCH] Pixtral polishing. (#2522)

* Pixtral polishing.

* Clippy fix.
---
 candle-examples/examples/pixtral/main.rs      | 15 +++--------
 .../src/models/pixtral/llava.rs               | 26 +++++++++++++++++++
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/candle-examples/examples/pixtral/main.rs b/candle-examples/examples/pixtral/main.rs
index 8e48b60b..79f43868 100644
--- a/candle-examples/examples/pixtral/main.rs
+++ b/candle-examples/examples/pixtral/main.rs
@@ -73,22 +73,18 @@ impl TextGeneration {
         let img_break = get_token("[IMG_BREAK]")?;
         let img_end = get_token("[IMG_END]")?;
         let start_gen = std::time::Instant::now();
-        let mut pos = 0;
         for index in 0..sample_len {
             let logits = if index > 0 {
                 let context_size = if index > 0 { 1 } else { tokens.len() };
                 let start_pos = tokens.len().saturating_sub(context_size);
                 let ctxt = &tokens[start_pos..];
                 let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-                let logits = self.model.language_model.forward(&input, pos)?;
-                pos += context_size;
-                logits
+                self.model.lm_forward(&input)?
             } else {
                 let (_b, _c, h, w) = self.image.dims4()?;
                 let h = h / self.model.patch_size;
                 let w = w / self.model.patch_size;
-                let image_embeds = self.model.vision_tower.forward(&self.image)?;
-                let image_embeds = self.model.multi_modal_projector.forward(&image_embeds)?;
+                let image_embeds = self.model.encode_image(&self.image)?;
                 println!("generated image embeddings {image_embeds:?}");
                 let image_embeds = image_embeds.to_dtype(self.model.dtype)?;
                 for &t in tokens.iter() {
@@ -124,12 +120,7 @@ impl TextGeneration {
                 input_embeds.push(end_embeds);
                 let input_embeds = Tensor::cat(&input_embeds, 1)?;
 
-                let logits = self
-                    .model
-                    .language_model
-                    .forward_embeds(&input_embeds, None, pos)?;
-                pos += input_embeds.dim(1)?;
-                logits
+                self.model.lm_forward_embeds(&input_embeds)?
             };
             let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs
index 33e0aca0..4aff26a7 100644
--- a/candle-transformers/src/models/pixtral/llava.rs
+++ b/candle-transformers/src/models/pixtral/llava.rs
@@ -48,6 +48,7 @@ pub struct Model {
     pub vision_tower: vision_model::Model,
     pub patch_size: usize,
     pub dtype: candle::DType,
+    pub pos: usize,
 }
 
 impl Model {
@@ -67,6 +68,31 @@ impl Model {
             vision_tower,
             patch_size: cfg.vision_config.patch_size,
             dtype: vb.dtype(),
+            pos: 0,
         })
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.language_model.clear_kv_cache();
+        self.pos = 0;
+    }
+
+    pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
+        let image_embeds = self.vision_tower.forward(image)?;
+        self.multi_modal_projector.forward(&image_embeds)
+    }
+
+    pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let (_, seq_len) = input_ids.dims2()?;
+        let logits = self.language_model.forward(input_ids, self.pos)?;
+        self.pos += seq_len;
+        Ok(logits)
+    }
+
+    pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let (_, seq_len, _) = xs.dims3()?;
+        let logits = self.language_model.forward_embeds(xs, None, self.pos)?;
+        self.pos += seq_len;
+        Ok(logits)
+    }
 }
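
Usage sketch (not part of the patch): the snippet below shows how the refactored Model API introduced here could be driven from an example binary: encode the image once with encode_image, prefill with already-built prompt embeddings via lm_forward_embeds, then decode token by token with lm_forward, letting the model track the position internally; clear_kv_cache resets both the cache and that position. The function name generate, the pre-built prefill_embeds argument, the fixed sampling seed, and the in-workspace crate name candle (as used by candle-examples) are assumptions for illustration; only encode_image, lm_forward, lm_forward_embeds, and clear_kv_cache come from this patch.

use candle::{DType, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::pixtral::llava::Model;

// Decoding loop against the new Model API; model construction, tokenization,
// and the interleaving of image and text embeddings are elided.
fn generate(
    model: &mut Model,
    image: &Tensor,          // (1, c, h, w) pre-processed image tensor
    prefill_embeds: &Tensor, // (1, seq, hidden) prompt embeddings, image slots already filled
    sample_len: usize,
) -> Result<Vec<u32>> {
    // Start from a clean KV cache; this also resets the internal position counter.
    model.clear_kv_cache();

    // Vision tower + multi-modal projector in a single call. In the real example
    // these embeddings are spliced between the [IMG_BREAK]/[IMG_END] token
    // embeddings to build `prefill_embeds`; that step is elided here.
    let _image_embeds = model.encode_image(image)?.to_dtype(model.dtype)?;

    // No temperature/top-p, so sampling reduces to argmax.
    let mut logits_processor = LogitsProcessor::new(42, None, None);
    let mut tokens = Vec::new();

    // Prefill: no manual `pos` bookkeeping, the model advances its position by
    // the sequence length of each input it sees.
    let mut logits = model.lm_forward_embeds(prefill_embeds)?;
    for _ in 0..sample_len {
        let last = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
        let next = logits_processor.sample(&last)?;
        tokens.push(next);
        // Decode one token at a time.
        let input = Tensor::new(&[next], image.device())?.unsqueeze(0)?;
        logits = model.lm_forward(&input)?;
    }
    Ok(tokens)
}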