Blip fixes (#1145)

* Some fixes for the blip example.

* Stop generating on sep tokens.

* Clippy fixes.

* rustfmt.
This commit is contained in:
Laurent Mazare
2023-10-21 21:34:48 +01:00
committed by GitHub
parent 0d9bb4eb18
commit 2531b13bf8
2 changed files with 68 additions and 31 deletions

View File

@@ -306,31 +306,4 @@ impl BlipForConditionalGeneration {
pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
&self.text_decoder
}
/// Generates a sequence of token ids conditioned on an image.
///
/// The image in `pixel_values` is run through the vision model once, then the
/// text decoder is called repeatedly on the tokens produced so far together
/// with the image embeddings; one new token is sampled from the logits of the
/// last position at every step. Generation stops as soon as the separator
/// token is sampled, or after a fixed maximum number of steps.
///
/// # Arguments
///
/// * `pixel_values` - image input fed to the vision model; only a leading
///   (batch) dimension of at most 1 is supported.
/// * `input_ids` - currently ignored (prompt conditioning is not implemented).
/// * `attention_mask` - currently ignored.
///
/// # Returns
///
/// A tensor holding the generated token ids, starting with the initial
/// start token and excluding the terminating separator token, broadcast on
/// the left to the batch size.
///
/// # Errors
///
/// Returns an error if the batch size is greater than 1, or if any tensor
/// operation, decoder forward pass, or sampling step fails.
pub fn generate(
    &self,
    pixel_values: &Tensor,
    input_ids: Option<&Tensor>,
    attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
    // Token that seeds the decoder.
    // NOTE(review): presumably the BLIP begin-of-sequence id — confirm against the tokenizer.
    const BOS_TOKEN_ID: u32 = 30522;
    // NOTE(review): 102 is `[SEP]` in the BERT uncased vocabulary, which the BLIP
    // tokenizer is assumed to share — confirm against the tokenizer config.
    const SEP_TOKEN_ID: u32 = 102;
    // Hard upper bound on the number of decoding steps.
    const MAX_NEW_TOKENS: usize = 1000;

    // Prompt conditioning is not implemented yet; the parameters are kept so the
    // signature stays stable for callers.
    let _ = (input_ids, attention_mask);

    let image_embeds = pixel_values.apply(&self.vision_model)?;
    let b_size = image_embeds.dim(0)?;
    if b_size > 1 {
        candle::bail!("only a batch size of 1 is supported")
    }

    let device = pixel_values.device();
    // Fixed seed, no temperature and no top-p configured.
    let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None);
    let mut token_ids = vec![BOS_TOKEN_ID];
    for _ in 0..MAX_NEW_TOKENS {
        // The whole sequence generated so far is fed back to the decoder each step.
        let input_ids = Tensor::new(token_ids.as_slice(), device)?.broadcast_left(b_size)?;
        let logits = self.text_decoder.forward(&input_ids, &image_embeds)?;
        // Drop the leading dimension, then keep only the logits of the last position.
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);
    }

    let generated = Tensor::new(token_ids.as_slice(), device)?.broadcast_left(b_size)?;
    Ok(generated)
}
}