Pixtral polishing. (#2522)

* Pixtral polishing.

* Clippy fix.
Laurent Mazare
2024-09-30 21:23:54 +02:00
committed by GitHub
parent 683ab698de
commit dfe9a00683
2 changed files with 29 additions and 12 deletions

@@ -73,22 +73,18 @@ impl TextGeneration {
         let img_break = get_token("[IMG_BREAK]")?;
         let img_end = get_token("[IMG_END]")?;
         let start_gen = std::time::Instant::now();
-        let mut pos = 0;
         for index in 0..sample_len {
             let logits = if index > 0 {
                 let context_size = if index > 0 { 1 } else { tokens.len() };
                 let start_pos = tokens.len().saturating_sub(context_size);
                 let ctxt = &tokens[start_pos..];
                 let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-                let logits = self.model.language_model.forward(&input, pos)?;
-                pos += context_size;
-                logits
+                self.model.lm_forward(&input)?
             } else {
                 let (_b, _c, h, w) = self.image.dims4()?;
                 let h = h / self.model.patch_size;
                 let w = w / self.model.patch_size;
-                let image_embeds = self.model.vision_tower.forward(&self.image)?;
-                let image_embeds = self.model.multi_modal_projector.forward(&image_embeds)?;
+                let image_embeds = self.model.encode_image(&self.image)?;
                 println!("generated image embeddings {image_embeds:?}");
                 let image_embeds = image_embeds.to_dtype(self.model.dtype)?;
                 for &t in tokens.iter() {
@@ -124,12 +120,7 @@ impl TextGeneration {
                 input_embeds.push(end_embeds);
                 let input_embeds = Tensor::cat(&input_embeds, 1)?;
-                let logits = self
-                    .model
-                    .language_model
-                    .forward_embeds(&input_embeds, None, pos)?;
-                pos += input_embeds.dim(1)?;
-                logits
+                self.model.lm_forward_embeds(&input_embeds)?
             };
             let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
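
Note: after this change the example no longer drives the vision tower, the multi-modal projector, and the language-model position counter by hand; it calls model-level helpers instead. The sketch below illustrates how such helpers could be implemented, based only on the calls visible in the removed lines. The field names (vision_tower, multi_modal_projector, language_model, pos) and signatures are assumptions for illustration, not the actual candle-transformers implementation.

// Hedged sketch only: possible shape of the helpers introduced by this refactor.
// Field names and inner signatures are assumptions derived from the removed lines.
use candle::{Result, Tensor};

impl Model {
    // Run the vision tower and the multi-modal projector on a preprocessed image.
    pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
        let embeds = self.vision_tower.forward(image)?;
        self.multi_modal_projector.forward(&embeds)
    }

    // Forward token ids through the language model, tracking the KV-cache
    // position internally instead of in the example loop.
    pub fn lm_forward(&mut self, input: &Tensor) -> Result<Tensor> {
        let (_b, seq_len) = input.dims2()?;
        let logits = self.language_model.forward(input, self.pos)?;
        self.pos += seq_len;
        Ok(logits)
    }

    // Same, but starting from precomputed (text + image) embeddings.
    pub fn lm_forward_embeds(&mut self, embeds: &Tensor) -> Result<Tensor> {
        let (_b, seq_len, _hidden) = embeds.dims3()?;
        let logits = self
            .language_model
            .forward_embeds(embeds, None, self.pos)?;
        self.pos += seq_len;
        Ok(logits)
    }
}

Keeping the position counter inside the model means the caller only decides what to feed in (token ids or embeddings), which is what lets the generation loop above shrink to single calls.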