From 2531b13bf85a69058e8ed1b30c683d19d036df14 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 21 Oct 2023 21:34:48 +0100
Subject: [PATCH] Blip fixes (#1145)

* Some fixes for the blip example.

* Stop generating on sep tokens.

* Clippy fixes.

* rustfmt.
---
 candle-examples/examples/blip/main.rs  | 72 ++++++++++++++++++++++++--
 candle-transformers/src/models/blip.rs | 27 ----------
 2 files changed, 68 insertions(+), 31 deletions(-)

diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index 82355778..54abab8c 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -4,17 +4,24 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
+use anyhow::Error as E;
 use clap::Parser;
 
-use candle::DType;
+use candle::{DType, Device, Result, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::models::blip;
+
+use tokenizers::Tokenizer;
 
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
     model: Option<String>,
 
+    #[arg(long)]
+    tokenizer: Option<String>,
+
     #[arg(long)]
     image: String,
 
@@ -23,12 +30,33 @@ struct Args {
     cpu: bool,
 }
 
+const SEP_TOKEN_ID: u32 = 102;
+
+/// Loads an image from disk using the image crate; this returns a tensor with shape
+/// (3, 384, 384). OpenAI normalization is applied.
+pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+    let img = image::io::Reader::open(p)?
+        .decode()
+        .map_err(candle::Error::wrap)?
+        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
+    let img = img.to_rgb8();
+    let data = img.into_raw();
+    let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+    let mean =
+        Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
+    let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
+        .reshape((3, 1, 1))?;
+    (data.to_dtype(candle::DType::F32)? / 255.)?
+        .broadcast_sub(&mean)?
+        .broadcast_div(&std)
+}
+
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
 
     let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image224(args.image)?;
+    let image = load_image(args.image)?.to_device(&device)?;
     println!("loaded image {image:?}");
 
     let model_file = match args.model {
@@ -43,12 +71,48 @@ pub fn main() -> anyhow::Result<()> {
         }
         Some(model) => model.into(),
     };
+    let tokenizer = match args.tokenizer {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("Salesforce/blip-image-captioning-large".to_string());
+            api.get("tokenizer.json")?
+        }
+        Some(file) => file.into(),
+    };
+    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+    let mut tokenizer = TokenOutputStream::new(tokenizer);
+    let mut logits_processor =
+        candle_transformers::generation::LogitsProcessor::new(1337, None, None);
+
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
     let config = blip::Config::image_captioning_large();
     let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
+    let vision_model = model.vision_model();
+    let text_decoder = model.text_decoder();
     println!("model built");
     // TODO: Maybe add support for the conditional prompt.
-    let out = model.generate(&image.unsqueeze(0)?, None, None)?;
-    println!(">>>\n{out}");
+    let image_embeds = image.unsqueeze(0)?.apply(vision_model)?;
+
+    let mut token_ids = vec![30522u32];
+    for _index in 0..1000 {
+        let input_ids = Tensor::new(token_ids.as_slice(), &device)?.broadcast_left(1)?;
+        let logits = text_decoder.forward(&input_ids, &image_embeds)?;
+        let logits = logits.squeeze(0)?;
+        let logits = logits.get(logits.dim(0)? - 1)?;
+        let token = logits_processor.sample(&logits)?;
+        if token == SEP_TOKEN_ID {
+            break;
+        }
+        token_ids.push(token);
+        if let Some(t) = tokenizer.next_token(token)? {
+            use std::io::Write;
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+    }
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
+
     Ok(())
 }
diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs
index b2be112e..1b4f9008 100644
--- a/candle-transformers/src/models/blip.rs
+++ b/candle-transformers/src/models/blip.rs
@@ -306,31 +306,4 @@ impl BlipForConditionalGeneration {
     pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
         &self.text_decoder
     }
-
-    pub fn generate(
-        &self,
-        pixel_values: &Tensor,
-        input_ids: Option<&Tensor>,
-        attention_mask: Option<&Tensor>,
-    ) -> Result<Tensor> {
-        let image_embeds = pixel_values.apply(&self.vision_model)?;
-        let b_size = image_embeds.dim(0)?;
-        if b_size > 1 {
-            candle::bail!("only a batch size of 1 is supported")
-        }
-        let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None);
-        let mut token_ids = vec![30522u32];
-        for i in 0..1000 {
-            let input_ids =
-                Tensor::new(token_ids.as_slice(), pixel_values.device())?.broadcast_left(b_size)?;
-            let logits = self.text_decoder.forward(&input_ids, &image_embeds)?;
-            println!("{logits:?}");
-            let logits = logits.squeeze(0)?;
-            let logits = logits.get(logits.dim(0)? - 1)?;
-            let token = logits_processor.sample(&logits)?;
-            println!("{token}");
-            token_ids.push(token)
-        }
-        todo!()
-    }
 }
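A note on the two hard-coded token ids in the new decoding loop: 30522 is the id BLIP uses for its decoder start token, and SEP_TOKEN_ID (102) is BERT's [SEP] token, which the caption decoder emits once the caption is complete, hence generation stops there. As a minimal sketch of an alternative, assuming the downloaded tokenizer.json ships the standard BERT special tokens, the id could be looked up at runtime with the tokenizers crate instead of being hard-coded:

    use tokenizers::Tokenizer;

    /// Looks up the id of BERT's [SEP] token instead of hard-coding 102.
    fn sep_token_id(tokenizer: &Tokenizer) -> anyhow::Result<u32> {
        tokenizer
            .token_to_id("[SEP]")
            .ok_or_else(|| anyhow::anyhow!("no [SEP] token in the tokenizer vocabulary"))
    }

Tokenizer::token_to_id returns None for unknown tokens, so a missing [SEP] surfaces as an error rather than a silently wrong id.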
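With these changes the example streams the caption token by token via TokenOutputStream rather than going through the removed generate method. Assuming the standard candle-examples layout, it can be run from the repository root with something like:

    cargo run --example blip --release -- --image path/to/image.jpg

Both --model and --tokenizer are optional; when omitted, the weights and tokenizer.json are fetched from Salesforce/blip-image-captioning-large on the Hugging Face hub, as the code above shows.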