#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use anyhow::Error as E; use clap::Parser; use candle::{DType, Device, Result, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; use candle_nn::VarBuilder; use candle_transformers::models::blip; use tokenizers::Tokenizer; #[derive(Parser)] struct Args { #[arg(long)] model: Option, #[arg(long)] tokenizer: Option, #[arg(long)] image: String, /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, } const SEP_TOKEN_ID: u32 = 102; /// Loads an image from disk using the image crate, this returns a tensor with shape /// (3, 384, 384). OpenAI normalization is applied. pub fn load_image>(p: P) -> Result { let img = image::io::Reader::open(p)? .decode() .map_err(candle::Error::wrap)? .resize_to_fill(384, 384, image::imageops::FilterType::Triangle); let img = img.to_rgb8(); let data = img.into_raw(); let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?; let mean = Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?; let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)? .reshape((3, 1, 1))?; (data.to_dtype(candle::DType::F32)? / 255.)? .broadcast_sub(&mean)? .broadcast_div(&std) } pub fn main() -> anyhow::Result<()> { let args = Args::parse(); let device = candle_examples::device(args.cpu)?; let image = load_image(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { None => { let api = hf_hub::api::sync::Api::new()?; let api = api.repo(hf_hub::Repo::with_revision( "Salesforce/blip-image-captioning-large".to_string(), hf_hub::RepoType::Model, "refs/pr/18".to_string(), )); api.get("model.safetensors")? } Some(model) => model.into(), }; let tokenizer = match args.tokenizer { None => { let api = hf_hub::api::sync::Api::new()?; let api = api.model("Salesforce/blip-image-captioning-large".to_string()); api.get("tokenizer.json")? } Some(file) => file.into(), }; let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; let mut tokenizer = TokenOutputStream::new(tokenizer); let mut logits_processor = candle_transformers::generation::LogitsProcessor::new(1337, None, None); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; let config = blip::Config::image_captioning_large(); let model = blip::BlipForConditionalGeneration::new(&config, vb)?; let vision_model = model.vision_model(); let text_decoder = model.text_decoder(); println!("model built"); // TODO: Maybe add support for the conditional prompt. let image_embeds = image.unsqueeze(0)?.apply(vision_model)?; let mut token_ids = vec![30522u32]; for _index in 0..1000 { let input_ids = Tensor::new(token_ids.as_slice(), &device)?.broadcast_left(1)?; let logits = text_decoder.forward(&input_ids, &image_embeds)?; let logits = logits.squeeze(0)?; let logits = logits.get(logits.dim(0)? - 1)?; let token = logits_processor.sample(&logits)?; if token == SEP_TOKEN_ID { break; } token_ids.push(token); if let Some(t) = tokenizer.next_token(token)? { use std::io::Write; print!("{t}"); std::io::stdout().flush()?; } } if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? { print!("{rest}"); } Ok(()) }