diff --git a/candle-examples/examples/paligemma/README.md b/candle-examples/examples/paligemma/README.md
new file mode 100644
index 00000000..56ae061e
--- /dev/null
+++ b/candle-examples/examples/paligemma/README.md
@@ -0,0 +1,28 @@
+# PaliGemma
+
+[HuggingFace Model Card](https://huggingface.co/google/paligemma-3b-pt-224) -
+[Model Page](https://ai.google.dev/gemma/docs/paligemma)
+
+```bash
+cargo run --features cuda --release --example paligemma -- \
+  --prompt "caption fr" --image candle-examples/examples/yolo-v8/assets/bike.jpg
+```
+
+```
+loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0]
+loaded the model in 1.267744448s
+caption fr. Un groupe de cyclistes qui sont dans la rue.
+13 tokens generated (56.52 token/s)
+```
+
+```bash
+cargo run --features cuda --release --example paligemma -- \
+  --prompt "caption fr" --image candle-examples/examples/flux/assets/flux-robot.jpg
+```
+
+```
+loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0]
+loaded the model in 1.271492621s
+caption fr une image d' un robot sur la plage avec le mot rouillé
+15 tokens generated (62.78 token/s)
+```
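The timings above come from a CUDA run. `main.rs` below also exposes a `--cpu` flag, in which case the weights are loaded as F32 rather than BF16; a CPU invocation would look as follows (expect far lower throughput, so this is mostly useful on machines without CUDA):

```bash
cargo run --release --example paligemma -- \
  --cpu --prompt "caption fr" --image candle-examples/examples/yolo-v8/assets/bike.jpg
```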
diff --git a/candle-examples/examples/paligemma/main.rs b/candle-examples/examples/paligemma/main.rs
new file mode 100644
index 00000000..9ce5011b
--- /dev/null
+++ b/candle-examples/examples/paligemma/main.rs
@@ -0,0 +1,276 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::Parser;
+
+use candle_transformers::models::paligemma::{Config, Model};
+
+use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+    model: Model,
+    image: Tensor,
+    device: Device,
+    tokenizer: TokenOutputStream,
+    logits_processor: LogitsProcessor,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
+impl TextGeneration {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        model: Model,
+        image: Tensor,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        device: &Device,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            image,
+            tokenizer: TokenOutputStream::new(tokenizer),
+            logits_processor,
+            repeat_penalty,
+            repeat_last_n,
+            device: device.clone(),
+        }
+    }
+
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        use std::io::Write;
+        self.tokenizer.clear();
+        let mut tokens = self
+            .tokenizer
+            .tokenizer()
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        for &t in tokens.iter() {
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                print!("{t}")
+            }
+        }
+        std::io::stdout().flush()?;
+
+        let mut generated_tokens = 0usize;
+        let eos_token = match self.tokenizer.get_token("<eos>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the <eos> token"),
+        };
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let start_pos = tokens.len().saturating_sub(context_size);
+            let ctxt = &tokens[start_pos..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = if index > 0 {
+                self.model.forward(&input)?
+            } else {
+                self.model.setup(&self.image, &input)?
+            };
+            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    self.repeat_penalty,
+                    &tokens[start_at..],
+                )?
+            };
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            generated_tokens += 1;
+            if next_token == eos_token {
+                break;
+            }
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+        }
+        let dt = start_gen.elapsed();
+        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+            print!("{rest}");
+        }
+        std::io::stdout().flush()?;
+        println!(
+            "\n{generated_tokens} tokens generated ({:.2} token/s)",
+            generated_tokens as f64 / dt.as_secs_f64(),
+        );
+        Ok(())
+    }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, short = 'n', default_value_t = 10000)]
+    sample_len: usize,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long, default_value = "main")]
+    revision: String,
+
+    #[arg(long)]
+    tokenizer_file: Option<String>,
+
+    #[arg(long)]
+    weight_files: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+
+    #[arg(long)]
+    image: String,
+}
+
+fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
+    let img = image::ImageReader::open(path)?.decode()?;
+    let (height, width) = (image_size, image_size);
+    let img = img.resize_to_fill(
+        width as u32,
+        height as u32,
+        image::imageops::FilterType::Triangle,
+    );
+    let img = img.to_rgb8();
+    let img = img.into_raw();
+    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
+        .permute((2, 0, 1))?
+        .to_dtype(DType::F32)?
+        .affine(2. / 255., -1.)?;
+    Ok(img)
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+    println!(
+        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+        args.temperature.unwrap_or(0.),
+        args.repeat_penalty,
+        args.repeat_last_n
+    );
+
+    let start = std::time::Instant::now();
+    let api = Api::new()?;
+    let model_id = match &args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => "google/paligemma-3b-mix-224".to_string(),
+    };
+    let repo = api.repo(Repo::with_revision(
+        model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+    let tokenizer_filename = match args.tokenizer_file {
+        Some(file) => std::path::PathBuf::from(file),
+        None => repo.get("tokenizer.json")?,
+    };
+    let filenames = match args.weight_files {
+        Some(files) => files
+            .split(',')
+            .map(std::path::PathBuf::from)
+            .collect::<Vec<_>>(),
+        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+    };
+    println!("retrieved the files in {:?}", start.elapsed());
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    let device = candle_examples::device(args.cpu)?;
+    let dtype = if device.is_cuda() {
+        DType::BF16
+    } else {
+        DType::F32
+    };
+    let config = Config::paligemma_3b_224();
+    let image = load_image(&args.image, config.vision_config.image_size)?
+        .to_device(&device)?
+        .to_dtype(dtype)?
+        .unsqueeze(0)?;
+    println!("loaded image with shape {:?}", image);
+    let start = std::time::Instant::now();
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+    let model = Model::new(&config, vb)?;
+    println!("loaded the model in {:?}", start.elapsed());
+
+    let mut pipeline = TextGeneration::new(
+        model,
+        image,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        args.repeat_penalty,
+        args.repeat_last_n,
+        &device,
+    );
+    let prompt = format!("{}\n", args.prompt);
+    pipeline.run(&prompt, args.sample_len)?;
+    Ok(())
+}
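Two preprocessing details in `main.rs` above are easy to miss. First, the prompt passed to `run` is suffixed with a newline (`format!("{}\n", args.prompt)`), which appears to match the task-prefix format PaliGemma expects. Second, `load_image` rescales pixels with `affine(2. / 255., -1.)`, i.e. `x * 2/255 - 1`, mapping the u8 range [0, 255] onto [-1, 1]. A standalone sanity check of that transform (a hypothetical snippet, reusing only the `affine` call from above):

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Three sample pixel values spanning the u8 range.
    let px = Tensor::new(&[0f32, 127.5, 255.], &Device::Cpu)?;
    // Same transform as load_image: x * 2/255 - 1.
    let scaled = px.affine(2. / 255., -1.)?;
    // Prints approximately [-1.0, 0.0, 1.0].
    println!("{scaled}");
    Ok(())
}
```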
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 1cfef59e..69e22678 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -362,6 +362,10 @@ impl Model {
         })
     }
 
+    pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+        &self.embed_tokens
+    }
+
     fn prepare_decoder_attention_mask(
         &self,
         b_size: usize,
@@ -400,6 +404,22 @@
             .apply(&self.lm_head)
     }
 
+    pub fn forward_embeds(
+        &mut self,
+        xs: &Tensor,
+        attn_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let (_, seq_len, _) = xs.dims3()?;
+        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+        }
+        xs.narrow(1, seq_len - 1, 1)?
+            .apply(&self.norm)?
+            .apply(&self.lm_head)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         for layer in self.layers.iter_mut() {
             layer.clear_kv_cache()
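The two additions to Gemma exist to support the multimodal path: `embed_tokens()` exposes the token-embedding table so a caller can embed text itself, and `forward_embeds` runs the decoder stack on precomputed embeddings rather than token ids, since PaliGemma prepends projected image features that have no corresponding ids. The attention mask is left to the caller; `paligemma::Model::setup` below passes `None`, which gives full (non-causal) attention over the image-plus-prompt prefix, consistent with PaliGemma's prefix-LM setup. A minimal sketch of how the two methods combine, mirroring `setup` (the `prefill` helper and its `image_embeds` argument are hypothetical):

```rust
use candle::{Module, Result, Tensor};
use candle_transformers::models::gemma;

// Sketch only: `image_embeds` is assumed to already be projected to the
// text hidden size, with shape (batch, num_patches, hidden_size).
fn prefill(model: &mut gemma::Model, image_embeds: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
    let text_embeds = model.embed_tokens().forward(input_ids)?;
    let input_embeds = Tensor::cat(&[image_embeds, &text_embeds], 1)?;
    // Mask `None`, offset 0: full attention over the whole prefix.
    model.forward_embeds(&input_embeds, None, 0)
}
```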
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a0e7a922..bba701bd 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -46,6 +46,7 @@ pub mod moondream;
 pub mod mpt;
 pub mod olmo;
 pub mod openclip;
+pub mod paligemma;
 pub mod parler_tts;
 pub mod persimmon;
 pub mod phi;
diff --git a/candle-transformers/src/models/paligemma.rs b/candle-transformers/src/models/paligemma.rs
new file mode 100644
index 00000000..e22ab241
--- /dev/null
+++ b/candle-transformers/src/models/paligemma.rs
@@ -0,0 +1,109 @@
+use crate::models::{gemma, siglip};
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+#[derive(serde::Deserialize, Clone, Debug)]
+pub struct Config {
+    pub vision_config: siglip::VisionConfig,
+    pub text_config: gemma::Config,
+    pub projection_dim: usize,
+}
+
+impl Config {
+    pub fn paligemma_3b_224() -> Self {
+        // https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
+        Self {
+            vision_config: siglip::VisionConfig::paligemma_3b_224(),
+            text_config: gemma::Config {
+                hidden_size: 2048,
+                intermediate_size: 16384,
+                num_attention_heads: 8,
+                num_hidden_layers: 18,
+                num_key_value_heads: 1,
+                vocab_size: 257216,
+                // Default values.
+                rope_theta: 10000.,
+                head_dim: 256,
+                hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
+                hidden_activation: None,
+                attention_bias: false,
+                max_position_embeddings: 8192,
+                rms_norm_eps: 1e-6,
+            },
+            projection_dim: 2048,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct MultiModalProjector {
+    linear: Linear,
+}
+
+impl MultiModalProjector {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let linear = linear(
+            cfg.vision_config.hidden_size,
+            cfg.projection_dim,
+            vb.pp("linear"),
+        )?;
+        Ok(Self { linear })
+    }
+}
+
+impl Module for MultiModalProjector {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.linear)
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct Model {
+    pos: usize,
+    vision_tower: siglip::VisionModel,
+    multi_modal_projector: MultiModalProjector,
+    language_model: gemma::Model,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vision_tower = siglip::VisionModel::new(
+            &cfg.vision_config,
+            false,
+            vb.pp("vision_tower.vision_model"),
+        )?;
+        let multi_modal_projector = MultiModalProjector::new(cfg, vb.pp("multi_modal_projector"))?;
+        let language_model = gemma::Model::new(false, &cfg.text_config, vb.pp("language_model"))?;
+        Ok(Self {
+            pos: 0,
+            language_model,
+            vision_tower,
+            multi_modal_projector,
+        })
+    }
+
+    pub fn setup(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
+        self.clear_kv_cache();
+        let image_features = self
+            .vision_tower
+            .forward(pixel_values)?
+            .apply(&self.multi_modal_projector)?;
+        let image_features = crate::models::clip::div_l2_norm(&image_features)?;
+        let text_features = self.language_model.embed_tokens().forward(input_ids)?;
+        let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
+        self.pos = input_embeds.dim(1)?;
+        self.language_model.forward_embeds(&input_embeds, None, 0)
+    }
+
+    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let pos = self.pos;
+        let seq_len = input_ids.dim(1)?;
+        self.pos = pos + seq_len;
+        self.language_model.forward(input_ids, pos)
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.pos = 0;
+        self.language_model.clear_kv_cache()
+    }
+}
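Putting the pieces together: `Model` keeps a running `pos` cursor into the KV cache. `setup` resets the cache, embeds the image through the SigLIP tower and the projector, L2-normalizes the features, concatenates them with the prompt embeddings, and runs the prefill at offset 0; each later `forward` feeds only newly sampled token ids at the current offset. A minimal greedy decoding loop over this API (a sketch of the loop in `main.rs`; the real example samples with `LogitsProcessor` and stops on `<eos>`, both omitted here):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_transformers::models::paligemma::Model;

// `model`, `image` (shape (1, 3, 224, 224)) and `prompt_ids` are assumed
// to be prepared exactly as in main.rs above.
fn generate(
    model: &mut Model,
    image: &Tensor,
    prompt_ids: &[u32],
    steps: usize,
    device: &Device,
) -> Result<Vec<u32>> {
    let mut tokens = prompt_ids.to_vec();
    for index in 0..steps {
        let logits = if index == 0 {
            // Prefill: vision tower + projector + full prompt; cache reset.
            let input = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
            model.setup(image, &input)?
        } else {
            // Decode: only the newest token; `pos` tracks the cache offset.
            let input = Tensor::new(&tokens[tokens.len() - 1..], device)?.unsqueeze(0)?;
            model.forward(&input)?
        };
        // Greedy pick of the next token from the last-position logits.
        let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
        tokens.push(logits.argmax(0)?.to_scalar::<u32>()?);
    }
    Ok(tokens)
}
```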