From 683ab698def755c24cec9987069d25efcf831fc4 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 30 Sep 2024 19:31:14 +0200
Subject: [PATCH] Add Pixtral. (#2521)

* Add Pixtral.
* More pixtral vision encoder.
* Sketch a pixtral example.
* Sketch a pixtral example.
* Better image loading.
* Support loading images embedded in safetensor files.
* Clippy fixes.
* Add the llava multimodal adapter.
* Add more of the llava bits.
* Add the pixtral config.
* More pixtral inference.
* Add the text generation bits.
* Get the example to work.
* Bugfix.
* Run some bits of the model in f32.
* Blessed version :)
* Better rope frequency computations.
* README update.
---
 candle-examples/examples/pixtral/README.md    |  28 ++
 candle-examples/examples/pixtral/main.rs      | 336 ++++++++++++++++++
 candle-nn/src/var_builder.rs                  |  36 +-
 candle-transformers/src/models/llava/mod.rs   |   2 +-
 candle-transformers/src/models/mistral.rs     |  38 +-
 candle-transformers/src/models/mod.rs         |   1 +
 .../src/models/pixtral/llava.rs               |  72 ++++
 candle-transformers/src/models/pixtral/mod.rs |   4 +
 .../src/models/pixtral/vision_model.rs        | 324 +++++++++++++++++
 9 files changed, 822 insertions(+), 19 deletions(-)
 create mode 100644 candle-examples/examples/pixtral/README.md
 create mode 100644 candle-examples/examples/pixtral/main.rs
 create mode 100644 candle-transformers/src/models/pixtral/llava.rs
 create mode 100644 candle-transformers/src/models/pixtral/mod.rs
 create mode 100644 candle-transformers/src/models/pixtral/vision_model.rs

diff --git a/candle-examples/examples/pixtral/README.md b/candle-examples/examples/pixtral/README.md
new file mode 100644
index 00000000..77677571
--- /dev/null
+++ b/candle-examples/examples/pixtral/README.md
@@ -0,0 +1,28 @@
+# pixtral
+
+Pixtral-12B is a 12B text+vision model.
+
+[Blog Post](https://mistral.ai/news/pixtral-12b/) -
+[HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) -
+[HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b).
+
+```bash
+cargo run --profile=release-with-debug --features cuda --example pixtral -- \
+  --image candle-examples/examples/flux/assets/flux-robot.jpg
+```
+
+```
+Describe the image.
+
+The image depicts a charming, rustic robot standing on a sandy beach at sunset.
+The robot has a vintage, steampunk aesthetic with visible gears and mechanical
+parts. It is holding a small lantern in one hand, which emits a warm glow, and
+its other arm is extended forward as if reaching out or guiding the way. The
+robot's body is adorned with the word "RUST" in bright orange letters, adding to
+its rustic theme.
+
+The background features a dramatic sky filled with clouds, illuminated by the
+setting sun, casting a golden hue over the scene. Gentle waves lap against the
+shore, creating a serene and picturesque atmosphere. The overall mood of the
+image is whimsical and nostalgic, evoking a sense of adventure and tranquility.
+``` diff --git a/candle-examples/examples/pixtral/main.rs b/candle-examples/examples/pixtral/main.rs new file mode 100644 index 00000000..8e48b60b --- /dev/null +++ b/candle-examples/examples/pixtral/main.rs @@ -0,0 +1,336 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::Parser; + +use candle_transformers::models::pixtral::{vision_model, Config, Model}; + +use candle::{DType, Device, Module, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + image: Tensor, + device: Device, + tokenizer: TokenOutputStream, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + image: Tensor, + tokenizer: Tokenizer, + seed: u64, + temp: Option, + top_p: Option, + repeat_penalty: f32, + repeat_last_n: usize, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + image, + tokenizer: TokenOutputStream::new(tokenizer), + logits_processor, + repeat_penalty, + repeat_last_n, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + use std::io::Write; + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let mut generated_tokens = 0usize; + let get_token = |v| match self.tokenizer.get_token(v) { + Some(token) => Ok(token), + None => anyhow::bail!("cannot find the {v} token"), + }; + let bos_token = get_token("")?; + let eos_token = get_token("")?; + let inst_token = get_token("[INST]")?; + let end_inst_token = get_token("[/INST]")?; + let img_break = get_token("[IMG_BREAK]")?; + let img_end = get_token("[IMG_END]")?; + let start_gen = std::time::Instant::now(); + let mut pos = 0; + for index in 0..sample_len { + let logits = if index > 0 { + let context_size = if index > 0 { 1 } else { tokens.len() }; + let start_pos = tokens.len().saturating_sub(context_size); + let ctxt = &tokens[start_pos..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self.model.language_model.forward(&input, pos)?; + pos += context_size; + logits + } else { + let (_b, _c, h, w) = self.image.dims4()?; + let h = h / self.model.patch_size; + let w = w / self.model.patch_size; + let image_embeds = self.model.vision_tower.forward(&self.image)?; + let image_embeds = self.model.multi_modal_projector.forward(&image_embeds)?; + println!("generated image embeddings {image_embeds:?}"); + let image_embeds = image_embeds.to_dtype(self.model.dtype)?; + for &t in tokens.iter() { + if let Some(t) = self.tokenizer.next_token(t)? { + print!("{t}") + } + } + std::io::stdout().flush()?; + + let break_embeds = { + let input = Tensor::new(&[img_break], &self.device)?.unsqueeze(0)?; + self.model.language_model.embed_tokens().forward(&input)? + }; + let start_embeds = { + let mut in_tokens = vec![bos_token, inst_token]; + in_tokens.extend_from_slice(tokens.as_slice()); + let input = Tensor::new(in_tokens.as_slice(), &self.device)?.unsqueeze(0)?; + self.model.language_model.embed_tokens().forward(&input)? 
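+ // embeddings for the BOS and [INST] tokens followed by the encoded user prompt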
+ }; + let end_embeds = { + let input = + Tensor::new(&[img_end, end_inst_token], &self.device)?.unsqueeze(0)?; + self.model.language_model.embed_tokens().forward(&input)? + }; + let mut input_embeds = vec![start_embeds]; + for h_idx in 0..h { + if h_idx > 0 { + input_embeds.push(break_embeds.clone()) + } + let row = image_embeds.narrow(1, h_idx * w, w)?; + input_embeds.push(row); + } + input_embeds.push(end_embeds); + + let input_embeds = Tensor::cat(&input_embeds, 1)?; + let logits = self + .model + .language_model + .forward_embeds(&input_embeds, None, pos)?; + pos += input_embeds.dim(1)?; + logits + }; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + break; + } + if let Some(t) = self.tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(()) + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long, default_value = "Describe the image.\n")] + prompt: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 10000)] + sample_len: usize, + + #[arg(long)] + model_id: Option, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long)] + tokenizer_file: Option, + + #[arg(long)] + config_file: Option, + + #[arg(long)] + weight_files: Option, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. 
+ #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + #[arg(long)] + image: String, + + #[arg(long)] + vision_only: bool, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let model_id = match &args.model_id { + Some(model_id) => model_id.to_string(), + None => "mistral-community/pixtral-12b".to_string(), + }; + let repo = api.repo(Repo::with_revision( + model_id, + RepoType::Model, + args.revision, + )); + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + println!("retrieved the files in {:?}", start.elapsed()); + + let device = candle_examples::device(args.cpu)?; + let dtype = if device.supports_bf16() && !args.vision_only { + DType::BF16 + } else { + DType::F32 + }; + let config: Config = match args.config_file { + Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?, + None => { + let config_file = repo.get("config.json")?; + serde_json::from_slice(&std::fs::read(config_file)?)? + } + }; + let image = if args.image.ends_with(".safetensors") { + match candle::safetensors::load(&args.image, &device)?.remove("img") { + None => anyhow::bail!("no img tensor in {}", args.image), + Some(v) => v, + } + } else { + candle_examples::imagenet::load_image_with_std_mean( + &args.image, + 1024, + &[0.48145466, 0.4578275, 0.40821073], + &[0.26862954, 0.261_302_6, 0.275_777_1], + )? + }; + let image = image.to_device(&device)?.unsqueeze(0)?; + println!("loaded image with shape {:?}", image); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? 
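+ // the weight files are memory-mapped, hence the unsafe block: the mapping assumes they are not modified while in use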
}; + + if args.vision_only { + let start = std::time::Instant::now(); + let model = vision_model::Model::new(&config.vision_config, vb.pp("vision_tower"))?; + println!("loaded the model in {:?}", start.elapsed()); + let embs = model.forward(&image)?; + println!("EMBS\n{embs}"); + } else { + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let start = std::time::Instant::now(); + let model = Model::new(&config, vb)?; + println!("loaded the model in {:?}", start.elapsed()); + let mut pipeline = TextGeneration::new( + model, + image, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + &device, + ); + pipeline.run(&args.prompt, args.sample_len)?; + } + + Ok(()) +} diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index f6e6160b..00669468 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -14,6 +14,7 @@ use std::sync::Arc; pub struct VarBuilderArgs<'a, B: Backend> { data: Arc>, path: Vec, + pub dtype: DType, _phantom: std::marker::PhantomData<&'a B>, } @@ -22,6 +23,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path: self.path.clone(), + dtype: self.dtype, _phantom: self._phantom, } } @@ -33,7 +35,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box>; struct TensorData { backend: B, - pub dtype: DType, pub device: Device, } @@ -95,12 +96,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self { let data = TensorData { backend, - dtype, device: dev.clone(), }; Self { data: Arc::new(data), path: vec![], + dtype, _phantom: std::marker::PhantomData, } } @@ -115,6 +116,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path: vec![], + dtype: self.dtype, _phantom: std::marker::PhantomData, } } @@ -124,6 +126,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path: vec![prefix.to_string()], + dtype: self.dtype, _phantom: std::marker::PhantomData, } } @@ -136,6 +139,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path, + dtype: self.dtype, _phantom: std::marker::PhantomData, } } @@ -152,7 +156,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { /// The dtype used by default. pub fn dtype(&self) -> DType { - self.data.dtype + self.dtype + } + + /// Clone the VarBuilder tweaking its dtype + pub fn to_dtype(&self, dtype: DType) -> Self { + Self { + data: self.data.clone(), + path: self.path.clone(), + dtype, + _phantom: std::marker::PhantomData, + } } fn path(&self, tensor_name: &str) -> String { @@ -178,7 +192,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { name: &str, hints: B::Hints, ) -> Result { - self.get_with_hints_dtype(s, name, hints, self.data.dtype) + self.get_with_hints_dtype(s, name, hints, self.dtype) } /// Retrieve the tensor associated with the given name at the current path. 
@@ -460,14 +474,11 @@ impl<'a> VarBuilder<'a> { dtype: DType, device: Device, ) -> Self { - let data = TensorData { - backend, - dtype, - device, - }; + let data = TensorData { backend, device }; Self { data: Arc::new(data), path: vec![], + dtype, _phantom: std::marker::PhantomData, } } @@ -567,13 +578,10 @@ impl<'a> VarBuilder<'a> { let path = self.path.clone(); let backend = Rename::new(self, renamer); let backend: Box = Box::new(backend); - let data = TensorData { - backend, - dtype, - device, - }; + let data = TensorData { backend, device }; Self { data: Arc::new(data), + dtype, path, _phantom: std::marker::PhantomData, } diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs index caa8737a..1ed3b50c 100644 --- a/candle-transformers/src/models/llava/mod.rs +++ b/candle-transformers/src/models/llava/mod.rs @@ -279,7 +279,7 @@ impl LLaVA { (), ))? } else { - todo!("not implemented in original python LLaVA yet") + bail!("not implemented in original python LLaVA yet") }; let new_image_feature = if mm_patch_merge_type.contains("unpad") { let new_image_feature = new_image_feature diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index 7e3b21c9..e8f7a7c4 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; +fn default_num_attention_heads() -> usize { + 32 +} + fn default_use_flash_attn() -> bool { false } +fn default_hidden_act() -> candle_nn::Activation { + candle_nn::Activation::Silu +} + #[derive(Debug, Clone, PartialEq, serde::Deserialize)] pub struct Config { pub vocab_size: usize, pub hidden_size: usize, pub intermediate_size: usize, pub num_hidden_layers: usize, + #[serde(default = "default_num_attention_heads")] pub num_attention_heads: usize, pub head_dim: Option, pub num_key_value_heads: usize, + #[serde(default = "default_hidden_act")] pub hidden_act: Activation, pub max_position_embeddings: usize, pub rms_norm_eps: f64, @@ -107,14 +117,14 @@ impl RotaryEmbedding { .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32)) .collect(); let inv_freq_len = inv_freq.len(); - let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?; let t = Tensor::arange(0u32, max_seq_len as u32, dev)? - .to_dtype(dtype)? + .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, + sin: freqs.sin()?.to_dtype(dtype)?, + cos: freqs.cos()?.to_dtype(dtype)?, }) } @@ -404,6 +414,10 @@ impl Model { .to_dtype(self.dtype) } + pub fn embed_tokens(&self) -> &candle_nn::Embedding { + &self.embed_tokens + } + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { let (_b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { @@ -421,6 +435,22 @@ impl Model { .apply(&self.lm_head) } + pub fn forward_embeds( + &mut self, + xs: &Tensor, + attn_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (_b_size, seq_len, _) = xs.dims3()?; + let mut xs = xs.clone(); + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attn_mask, seqlen_offset)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? 
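+ // project the normalized last position to vocabulary logits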
+ .apply(&self.lm_head) + } + pub fn clear_kv_cache(&mut self) { for layer in self.layers.iter_mut() { layer.clear_kv_cache() diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index bba701bd..09876503 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -51,6 +51,7 @@ pub mod parler_tts; pub mod persimmon; pub mod phi; pub mod phi3; +pub mod pixtral; pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs new file mode 100644 index 00000000..33e0aca0 --- /dev/null +++ b/candle-transformers/src/models/pixtral/llava.rs @@ -0,0 +1,72 @@ +use candle::{Module, Result, Tensor}; +use candle_nn::{linear, Linear, VarBuilder}; + +use super::vision_model; +use crate::models::mistral; + +#[derive(serde::Deserialize, Debug, Clone)] +pub struct Config { + pub projector_hidden_act: candle_nn::Activation, + pub text_config: mistral::Config, + pub vision_config: vision_model::Config, + pub image_token_index: usize, + pub image_seq_length: usize, +} + +#[derive(Debug, Clone)] +pub struct MultiModalProjector { + linear_1: Linear, + act: candle_nn::Activation, + linear_2: Linear, +} + +impl MultiModalProjector { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size); + let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?; + let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?; + Ok(Self { + linear_1, + act: cfg.projector_hidden_act, + linear_2, + }) + } +} + +impl Module for MultiModalProjector { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.linear_1)? + .apply(&self.act)? 
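+ // linear_1 above already mapped the vision hidden size to the text hidden size; this second projection keeps that width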
+ .apply(&self.linear_2) + } +} + +#[derive(Debug, Clone)] +pub struct Model { + pub multi_modal_projector: MultiModalProjector, + pub language_model: mistral::Model, + pub vision_tower: vision_model::Model, + pub patch_size: usize, + pub dtype: candle::DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?; + let vision_tower = vision_model::Model::new( + &cfg.vision_config, + vb.pp("vision_tower").to_dtype(candle::DType::F32), + )?; + let multi_modal_projector = MultiModalProjector::new( + cfg, + vb.pp("multi_modal_projector").to_dtype(candle::DType::F32), + )?; + Ok(Self { + multi_modal_projector, + language_model, + vision_tower, + patch_size: cfg.vision_config.patch_size, + dtype: vb.dtype(), + }) + } +} diff --git a/candle-transformers/src/models/pixtral/mod.rs b/candle-transformers/src/models/pixtral/mod.rs new file mode 100644 index 00000000..9d0eccfb --- /dev/null +++ b/candle-transformers/src/models/pixtral/mod.rs @@ -0,0 +1,4 @@ +pub mod llava; +pub mod vision_model; + +pub use llava::{Config, Model}; diff --git a/candle-transformers/src/models/pixtral/vision_model.rs b/candle-transformers/src/models/pixtral/vision_model.rs new file mode 100644 index 00000000..20d8f082 --- /dev/null +++ b/candle-transformers/src/models/pixtral/vision_model.rs @@ -0,0 +1,324 @@ +use candle::{DType, Module, Result, Tensor, D}; +use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder}; + +fn default_act() -> candle_nn::Activation { + candle_nn::Activation::Gelu +} + +fn default_hidden_size() -> usize { + 1024 +} + +fn default_intermediate_size() -> usize { + 4096 +} + +fn default_num_channels() -> usize { + 3 +} + +fn default_num_hidden_layers() -> usize { + 24 +} + +fn default_num_attention_heads() -> usize { + 16 +} + +#[derive(serde::Deserialize, Debug, Clone)] +pub struct Config { + #[serde(default = "default_hidden_size")] + pub hidden_size: usize, + #[serde(default = "default_num_channels")] + pub num_channels: usize, + pub image_size: usize, + pub patch_size: usize, + pub rope_theta: f64, + #[serde(default = "default_intermediate_size")] + pub intermediate_size: usize, + #[serde(default = "default_num_hidden_layers")] + pub num_hidden_layers: usize, + pub head_dim: Option, + #[serde(default = "default_num_attention_heads")] + pub num_attention_heads: usize, + #[serde(default = "default_act")] + pub hidden_act: candle_nn::Activation, +} + +impl Config { + pub fn pixtral_12b_2409() -> Self { + Self { + hidden_size: 1024, + num_channels: 3, + image_size: 1024, + patch_size: 16, + rope_theta: 10000.0, + intermediate_size: 4096, + num_hidden_layers: 24, + num_attention_heads: 16, + head_dim: None, + // Default + hidden_act: candle_nn::Activation::Gelu, + } + } + + fn head_dim(&self) -> usize { + self.head_dim + .unwrap_or(self.hidden_size / self.num_attention_heads) + } +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + scale: f64, + num_heads: usize, + head_dim: usize, +} + +impl Attention { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let head_dim = cfg.head_dim(); + let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?; + let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?; + let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?; + let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?; + let scale = (head_dim as f64).powf(-0.5); 
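+ // standard 1/sqrt(head_dim) attention scaling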
+ Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + scale, + num_heads, + head_dim, + }) + } + + fn forward( + &self, + xs: &Tensor, + emb: &RotaryEmbedding, + attention_mask: Option<&Tensor>, + ) -> Result { + let (b, patches, _) = xs.dims3()?; + let query_states = xs.apply(&self.q_proj)?; + let key_states = xs.apply(&self.k_proj)?; + let value_states = xs.apply(&self.v_proj)?; + + let shape = (b, patches, self.num_heads, self.head_dim); + let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?; + let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?; + let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?; + + let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?; + let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights + .matmul(&value_states)? + .transpose(1, 2)? + .reshape((b, patches, ()))? + .apply(&self.o_proj) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: candle_nn::Activation, +} + +impl Mlp { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let (h, i) = (cfg.hidden_size, cfg.intermediate_size); + let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?; + let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?; + let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + (xs.apply(&self.gate_proj)?.apply(&self.act_fn)? * xs.apply(&self.up_proj))? + .apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +struct AttentionLayer { + attention_norm: RmsNorm, + feed_forward: Mlp, + attention: Attention, + ffn_norm: RmsNorm, +} + +impl AttentionLayer { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?; + let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?; + let attention = Attention::new(cfg, vb.pp("attention"))?; + let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?; + Ok(Self { + attention_norm, + feed_forward, + attention, + ffn_norm, + }) + } + + fn forward( + &self, + xs: &Tensor, + emb: &RotaryEmbedding, + attention_mask: Option<&Tensor>, + ) -> Result { + let residual = xs; + let xs = self + .attention + .forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?; + let xs = (residual + xs)?; + let residual = &xs; + let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?; + xs + residual + } +} + +#[derive(Debug, Clone)] +struct Transformer { + layers: Vec, +} + +impl Transformer { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb = vb.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?; + layers.push(layer) + } + Ok(Self { layers }) + } + + fn forward( + &self, + xs: &Tensor, + emb: &RotaryEmbedding, + attention_mask: Option<&Tensor>, + ) -> Result { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, emb, attention_mask)? 
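+ // each layer: pre-norm attention with the 2D rotary embedding, then a gated MLP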
+ } + Ok(xs) + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + cos: Tensor, + sin: Tensor, +} + +impl RotaryEmbedding { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let dtype = vb.dtype(); + let dev = vb.device(); + let dim = cfg.head_dim(); + let rope_theta = cfg.rope_theta as f32; + let max_patches_per_side = cfg.image_size / cfg.patch_size; + let freqs: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32)) + .collect(); + let freqs_h = freqs.iter().step_by(2).copied().collect::>(); + let freqs_h = Tensor::new(freqs_h, dev)?; + let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::>(); + let freqs_w = Tensor::new(freqs_w, dev)?; + let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?; + let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?; + let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?; + let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?; + let inv_freq = Tensor::cat( + &[ + freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?, + freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?, + ], + D::Minus1, + )? + .reshape(((), dim / 2))?; + let cos = inv_freq.cos()?.to_dtype(dtype)?; + let sin = inv_freq.sin()?.to_dtype(dtype)?; + Ok(Self { cos, sin }) + } + + fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?; + let cos = &self.cos; + let sin = &self.sin; + let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?; + let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +pub struct Model { + patch_conv: candle_nn::Conv2d, + ln_pre: RmsNorm, + transformer: Transformer, + patch_positional_embedding: RotaryEmbedding, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let conv2d_cfg = candle_nn::Conv2dConfig { + stride: cfg.patch_size, + ..Default::default() + }; + let patch_conv = candle_nn::conv2d_no_bias( + cfg.num_channels, + cfg.hidden_size, + cfg.patch_size, + conv2d_cfg, + vb.pp("patch_conv"), + )?; + let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?; + let transformer = Transformer::new(cfg, vb.pp("transformer"))?; + let patch_positional_embedding = + RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?; + Ok(Self { + patch_conv, + ln_pre, + transformer, + patch_positional_embedding, + }) + } +} + +impl Module for Model { + fn forward(&self, xs: &Tensor) -> Result { + let patch_embeds = xs.apply(&self.patch_conv)?; + let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?; + self.transformer + .forward(&patch_embeds, &self.patch_positional_embedding, None) + } +}
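
As a rough usage sketch (not taken from the patch itself): the new `candle_transformers::models::pixtral::vision_model` API added above can be driven on its own, mirroring the `--vision-only` path of the example. The weight path below is a placeholder, the configuration uses the built-in `pixtral_12b_2409()` defaults, and the `anyhow` crate is assumed for error handling, as in the example.

```rust
use candle::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::pixtral::vision_model;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Built-in defaults matching the 12B-2409 checkpoint: 1024x1024 images, 16x16 patches.
    let config = vision_model::Config::pixtral_12b_2409();
    // Placeholder path; the real checkpoint is sharded (see hub_load_safetensors in the example).
    let weights = ["pixtral.safetensors"];
    // The example runs the vision parts in f32 (the language model stays in bf16 via the
    // VarBuilder::to_dtype helper added in this patch); plain f32 is used throughout here.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weights, DType::F32, &device)? };
    // The vision encoder weights live under the "vision_tower" prefix.
    let model = vision_model::Model::new(&config, vb.pp("vision_tower"))?;
    // Dummy 1024x1024 RGB input; the example normalizes real images with CLIP mean/std.
    let image = Tensor::zeros((1, 3, 1024, 1024), DType::F32, &device)?;
    // Patch embeddings of shape (1, (1024 / 16)^2, hidden_size) = (1, 4096, 1024).
    let embeddings = model.forward(&image)?;
    println!("{:?}", embeddings.shape());
    Ok(())
}
```

For the full multimodal path, `pixtral::Model::new` builds the vision tower and projector in f32 and the Mistral language model in the `VarBuilder`'s dtype, and the example then interleaves image-row embeddings with `[IMG_BREAK]`/`[IMG_END]` token embeddings before calling `forward_embeds`.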