mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Blip fixes (#1145)
* Some fixes for the blip example. * Stop generating on sep tokens. * Clippy fixes. * rustfmt.
This commit is contained in:
@ -306,31 +306,4 @@ impl BlipForConditionalGeneration {
|
||||
pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
|
||||
&self.text_decoder
|
||||
}
|
||||
|
||||
pub fn generate(
|
||||
&self,
|
||||
pixel_values: &Tensor,
|
||||
input_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let image_embeds = pixel_values.apply(&self.vision_model)?;
|
||||
let b_size = image_embeds.dim(0)?;
|
||||
if b_size > 1 {
|
||||
candle::bail!("only a batch size of 1 is supported")
|
||||
}
|
||||
let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None);
|
||||
let mut token_ids = vec![30522u32];
|
||||
for i in 0..1000 {
|
||||
let input_ids =
|
||||
Tensor::new(token_ids.as_slice(), pixel_values.device())?.broadcast_left(b_size)?;
|
||||
let logits = self.text_decoder.forward(&input_ids, &image_embeds)?;
|
||||
println!("{logits:?}");
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||
let token = logits_processor.sample(&logits)?;
|
||||
println!("{token}");
|
||||
token_ids.push(token)
|
||||
}
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user