From 86613c00e216750f32a326dbff5cc993d5e0067e Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Thu, 29 Aug 2024 16:38:58 +0300
Subject: [PATCH] MobileCLIP models S1 and S2 (#2454)

* Allow loading images with given std and mean

* OpenCLIP text encoder component

* Two MobileCLIP models

* Clippy fixes.

---------

Co-authored-by: Laurent
---
 candle-examples/examples/mobileclip/README.md |  28 ++
 candle-examples/examples/mobileclip/main.rs   | 192 +++++++++++++
 candle-examples/examples/mobilenetv4/main.rs  |   5 +-
 candle-examples/src/imagenet.rs               |  35 ++-
 candle-transformers/src/models/mobileclip.rs  |  89 ++++++
 candle-transformers/src/models/mod.rs         |   2 +
 .../src/models/openclip/mod.rs                |   1 +
 .../src/models/openclip/text_model.rs         | 266 ++++++++++++++++++
 8 files changed, 608 insertions(+), 10 deletions(-)
 create mode 100644 candle-examples/examples/mobileclip/README.md
 create mode 100644 candle-examples/examples/mobileclip/main.rs
 create mode 100644 candle-transformers/src/models/mobileclip.rs
 create mode 100644 candle-transformers/src/models/openclip/mod.rs
 create mode 100644 candle-transformers/src/models/openclip/text_model.rs

diff --git a/candle-examples/examples/mobileclip/README.md b/candle-examples/examples/mobileclip/README.md
new file mode 100644
index 00000000..a3869b25
--- /dev/null
+++ b/candle-examples/examples/mobileclip/README.md
@@ -0,0 +1,28 @@
+# candle-mobileclip
+
+MobileCLIP is a family of efficient CLIP-like models that use FastViT-based image encoders.
+
+See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049)
+
+
+## Running the example on CPU
+
+```
+$ cargo run --example mobileclip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
+
+softmax_image_vec: [2.4819004e-5, 3.81081e-6, 0.9999714, 0.9999738, 2.382714e-5, 2.3317718e-6]
+
+
+Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
+
+Probability: 0.0025% Text: a cycling race
+Probability: 0.0004% Text: a photo of two cats
+Probability: 99.9971% Text: a robot holding a candle
+
+
+Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
+
+Probability: 99.9974% Text: a cycling race
+Probability: 0.0024% Text: a photo of two cats
+Probability: 0.0002% Text: a robot holding a candle
+```
diff --git a/candle-examples/examples/mobileclip/main.rs b/candle-examples/examples/mobileclip/main.rs
new file mode 100644
index 00000000..d505fc7c
--- /dev/null
+++ b/candle-examples/examples/mobileclip/main.rs
@@ -0,0 +1,192 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Error as E;
+use clap::{Parser, ValueEnum};
+
+use candle::{DType, Device, Tensor};
+use candle_nn::{ops::softmax, VarBuilder};
+use candle_transformers::models::mobileclip;
+
+use tokenizers::Tokenizer;
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+    S1,
+    S2,
+}
+
+impl Which {
+    fn model_name(&self) -> String {
+        let name = match self {
+            Self::S1 => "S1",
+            Self::S2 => "S2",
+        };
+        format!("apple/MobileCLIP-{}-OpenCLIP", name)
+    }
+
+    fn config(&self) -> mobileclip::MobileClipConfig {
+        match self {
+            Self::S1 => mobileclip::MobileClipConfig::s1(),
+            Self::S2 => mobileclip::MobileClipConfig::s2(),
+        }
+    }
+}
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long, use_value_delimiter = true)]
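+    /// Comma-separated paths of the images to score against the text prompts.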
+    images: Option<Vec<String>>,
+
+    #[arg(long)]
+    cpu: bool,
+
+    /// Use the pytorch weights rather than the safetensors ones
+    #[arg(long)]
+    use_pth: bool,
+
+    #[arg(long, use_value_delimiter = true)]
+    sequences: Option<Vec<String>>,
+
+    #[arg(value_enum, long, default_value_t=Which::S1)]
+    which: Which,
+}
+
+fn load_images<T: AsRef<std::path::Path>>(
+    paths: &Vec<T>,
+    image_size: usize,
+) -> anyhow::Result<Tensor> {
+    let mut images = vec![];
+
+    for path in paths {
+        let tensor = candle_examples::imagenet::load_image_with_std_mean(
+            path,
+            image_size,
+            &[0.0, 0.0, 0.0],
+            &[1.0, 1.0, 1.0],
+        )?;
+        images.push(tensor);
+    }
+
+    let images = Tensor::stack(&images, 0)?;
+
+    Ok(images)
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    let model_name = args.which.model_name();
+
+    let api = hf_hub::api::sync::Api::new()?;
+    let api = api.model(model_name);
+
+    let model_file = if args.use_pth {
+        api.get("open_clip_pytorch_model.bin")?
+    } else {
+        api.get("open_clip_model.safetensors")?
+    };
+
+    let tokenizer = api.get("tokenizer.json")?;
+
+    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+
+    let config = &args.which.config();
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let vec_imgs = match args.images {
+        Some(imgs) => imgs,
+        None => vec![
+            "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
+            "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
+        ],
+    };
+
+    let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
+
+    let vb = if args.use_pth {
+        VarBuilder::from_pth(&model_file, DType::F32, &device)?
+    } else {
+        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }
+    };
+
+    let model = mobileclip::MobileClipModel::new(vb, config)?;
+
+    let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
+
+    let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
+
+    let softmax_image = softmax(&logits_per_image, 1)?;
+
+    let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
+
+    println!("softmax_image_vec: {:?}", softmax_image_vec);
+
+    let probability_vec = softmax_image_vec
+        .iter()
+        .map(|v| v * 100.0)
+        .collect::<Vec<f32>>();
+
+    let probability_per_image = probability_vec.len() / vec_imgs.len();
+
+    for (i, img) in vec_imgs.iter().enumerate() {
+        let start = i * probability_per_image;
+        let end = start + probability_per_image;
+        let prob = &probability_vec[start..end];
+        println!("\n\nResults for image: {}\n", img);
+
+        for (i, p) in prob.iter().enumerate() {
+            println!("Probability: {:.4}% Text: {}", p, vec_seq[i]);
+        }
+    }
+
+    Ok(())
+}
+
+pub fn tokenize_sequences(
+    sequences: Option<Vec<String>>,
+    tokenizer: &Tokenizer,
+    device: &Device,
+) -> anyhow::Result<(Tensor, Vec<String>)> {
+    // let pad_id = *tokenizer
+    //     .get_vocab(true)
+    //     .get("<|endoftext|>")
+    //     .ok_or(E::msg("No pad token"))?;
+
+    // The model does not work well if the text is padded using the <|endoftext|> token, so 0 is
+    // used for padding, as in the original OpenCLIP code.
+    let pad_id = 0;
+
+    let vec_seq = match sequences {
+        Some(seq) => seq,
+        None => vec![
+            "a cycling race".to_string(),
+            "a photo of two cats".to_string(),
+            "a robot holding a candle".to_string(),
+        ],
+    };
+
+    let mut tokens = vec![];
+
+    for seq in vec_seq.clone() {
+        let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
+        tokens.push(encoding.get_ids().to_vec());
+    }
+
+    let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
+
+    // Pad the sequences to have the same length
+    for token_vec in tokens.iter_mut() {
+        let len_diff = max_len - token_vec.len();
+        if len_diff > 0 {
+            token_vec.extend(vec![pad_id; len_diff]);
+        }
+    }
+
+    let input_ids = Tensor::new(tokens, device)?;
+
+    Ok((input_ids, vec_seq))
+}
diff --git a/candle-examples/examples/mobilenetv4/main.rs b/candle-examples/examples/mobilenetv4/main.rs
index 26c0dad9..c31b91e6 100644
--- a/candle-examples/examples/mobilenetv4/main.rs
+++ b/candle-examples/examples/mobilenetv4/main.rs
@@ -72,8 +72,9 @@ pub fn main() -> anyhow::Result<()> {
 
     let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())?
-        .to_device(&device)?;
+    let image =
+        candle_examples::imagenet::load_image(args.image, args.which.resolution() as usize)?
+            .to_device(&device)?;
     println!("loaded image {image:?}");
 
     let model_file = match args.model {
diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs
index 6fcda424..a3b12423 100644
--- a/candle-examples/src/imagenet.rs
+++ b/candle-examples/src/imagenet.rs
@@ -1,23 +1,42 @@
 use candle::{Device, Result, Tensor};
 
-/// Loads an image from disk using the image crate at the requested resolution.
-// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
-pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
+pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
+pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
+
+/// Loads an image from disk using the image crate at the requested resolution,
+/// using the given std and mean parameters.
+/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
+pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
+    p: P,
+    res: usize,
+    mean: &[f32; 3],
+    std: &[f32; 3],
+) -> Result<Tensor> {
     let img = image::ImageReader::open(p)?
         .decode()
         .map_err(candle::Error::wrap)?
-        .resize_to_fill(res, res, image::imageops::FilterType::Triangle);
+        .resize_to_fill(
+            res as u32,
+            res as u32,
+            image::imageops::FilterType::Triangle,
+        );
     let img = img.to_rgb8();
     let data = img.into_raw();
-    let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)?
-        .permute((2, 0, 1))?;
-    let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
-    let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
+    let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+    let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?;
+    let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?;
     (data.to_dtype(candle::DType::F32)? / 255.)?
         .broadcast_sub(&mean)?
         .broadcast_div(&std)
 }
 
+/// Loads an image from disk using the image crate at the requested resolution.
+/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
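+///
+/// A minimal usage sketch (the file name is illustrative):
+/// ```no_run
+/// fn main() -> candle::Result<()> {
+///     let image = candle_examples::imagenet::load_image("bike.jpg", 224)?;
+///     println!("{:?}", image.shape()); // expected shape: [3, 224, 224]
+///     Ok(())
+/// }
+/// ```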
+pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: usize) -> Result<Tensor> {
+    load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD)
+}
+
 /// Loads an image from disk using the image crate, this returns a tensor with shape
 /// (3, 224, 224). imagenet normalization is applied.
 pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
diff --git a/candle-transformers/src/models/mobileclip.rs b/candle-transformers/src/models/mobileclip.rs
new file mode 100644
index 00000000..4953d835
--- /dev/null
+++ b/candle-transformers/src/models/mobileclip.rs
@@ -0,0 +1,89 @@
+use super::fastvit;
+use super::openclip::text_model;
+use candle::{Result, Tensor, D};
+use candle_nn::{Func, VarBuilder};
+
+#[derive(Clone, Debug)]
+pub struct MobileClipModel {
+    text_model: text_model::OpenClipTextTransformer,
+    vision_model: Func<'static>,
+    text_projection: Tensor,
+    logit_scale: Tensor,
+}
+
+#[derive(Clone, Debug)]
+pub struct MobileClipConfig {
+    pub text_config: text_model::Config,
+    pub vision_config: fastvit::Config,
+    pub image_size: usize,
+}
+
+impl MobileClipConfig {
+    pub fn s1() -> Self {
+        let text_config = text_model::Config::vit_base_patch32();
+        let vision_config = fastvit::Config::mci1();
+
+        Self {
+            text_config,
+            vision_config,
+            image_size: 256,
+        }
+    }
+    pub fn s2() -> Self {
+        let text_config = text_model::Config::vit_base_patch32();
+        let vision_config = fastvit::Config::mci2();
+
+        Self {
+            text_config,
+            vision_config,
+            image_size: 256,
+        }
+    }
+}
+
+impl MobileClipModel {
+    pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
+        let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
+        let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
+
+        let text_projection = vs.get(
+            (c.text_config.embed_dim, c.text_config.projection_dim),
+            "text.text_projection",
+        )?;
+
+        let logit_scale = vs.get(&[], "logit_scale")?;
+        Ok(Self {
+            text_model,
+            vision_model,
+            text_projection,
+            logit_scale,
+        })
+    }
+
+    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
+        input_ids
+            .apply(&self.text_model)?
+            .matmul(&self.text_projection)
+    }
+
+    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
+        pixel_values.apply(&self.vision_model)
+    }
+
+    pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
+        let image_features = self.get_image_features(pixel_values)?;
+        let text_features = self.get_text_features(input_ids)?;
+        let image_features_normalized = div_l2_norm(&image_features)?;
+        let text_features_normalized = div_l2_norm(&text_features)?;
+        let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
+        let logit_scale = self.logit_scale.exp()?;
+        let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
+        let logits_per_image = logits_per_text.t()?;
+        Ok((logits_per_text, logits_per_image))
+    }
+}
+
+pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
+    let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
+    v.broadcast_div(&l2_norm)
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a234b8bb..9f7856ea 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -37,11 +37,13 @@ pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;
 pub mod mmdit;
+pub mod mobileclip;
 pub mod mobilenetv4;
 pub mod mobileone;
 pub mod moondream;
 pub mod mpt;
 pub mod olmo;
+pub mod openclip;
 pub mod parler_tts;
 pub mod persimmon;
 pub mod phi;
diff --git a/candle-transformers/src/models/openclip/mod.rs b/candle-transformers/src/models/openclip/mod.rs
new file mode 100644
index 00000000..ee2a501d
--- /dev/null
+++ b/candle-transformers/src/models/openclip/mod.rs
@@ -0,0 +1 @@
+pub mod text_model;
diff --git a/candle-transformers/src/models/openclip/text_model.rs b/candle-transformers/src/models/openclip/text_model.rs
new file mode 100644
index 00000000..7b444e79
--- /dev/null
+++ b/candle-transformers/src/models/openclip/text_model.rs
@@ -0,0 +1,266 @@
+//! Text encoder as used in most OpenCLIP pretrained models
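+//!
+//! The encoder follows the OpenCLIP `text` tower layout: token embeddings plus learned
+//! positional embeddings, a stack of pre-norm residual attention blocks, and a final
+//! layer norm. See: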
+//! https://github.com/mlfoundations/open_clip
+
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{
+    embedding, layer_norm, linear, ops::softmax_last_dim, Embedding, LayerNorm, Linear, Module,
+    VarBuilder,
+};
+
+#[derive(Debug, Clone)]
+pub struct Config {
+    pub vocab_size: usize,
+    pub embed_dim: usize,
+    pub intermediate_size: usize,
+    pub max_position_embeddings: usize,
+    pub pad_with: Option<String>,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub projection_dim: usize,
+}
+
+impl Config {
+    pub fn vit_base_patch32() -> Self {
+        Self {
+            vocab_size: 49408,
+            embed_dim: 512,
+            intermediate_size: 2048,
+            max_position_embeddings: 77,
+            pad_with: None,
+            num_hidden_layers: 12,
+            num_attention_heads: 8,
+            projection_dim: 512,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+struct TextEmbeddings {
+    token_embedding: Embedding,
+    position_embedding: Tensor,
+}
+
+impl TextEmbeddings {
+    fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
+        let token_embedding = embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
+        let position_embedding = vs.get(
+            (c.max_position_embeddings, c.embed_dim),
+            "positional_embedding",
+        )?;
+        Ok(TextEmbeddings {
+            token_embedding,
+            position_embedding,
+        })
+    }
+}
+
+impl Module for TextEmbeddings {
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let seq_length = input_ids.dim(D::Minus1)?;
+        let inputs_embeds = self.token_embedding.forward(input_ids)?;
+
+        let position_embedding = self.position_embedding.narrow(0, 0, seq_length)?;
+
+        inputs_embeds.broadcast_add(&position_embedding)
+    }
+}
+
+#[derive(Clone, Debug)]
+struct Attention {
+    k_proj: candle_nn::Linear,
+    v_proj: candle_nn::Linear,
+    q_proj: candle_nn::Linear,
+    out_proj: Linear,
+    head_dim: usize,
+    scale: f64,
+    num_attention_heads: usize,
+}
+
+impl Attention {
+    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+        let embed_dim = c.embed_dim;
+        let num_attention_heads = c.num_attention_heads;
+
+        let in_proj_weights = vs
+            .get((embed_dim * 3, embed_dim), "in_proj_weight")?
+            .chunk(3, 0)?;
+        let (q_w, k_w, v_w) = (
+            &in_proj_weights[0],
+            &in_proj_weights[1],
+            &in_proj_weights[2],
+        );
+        let in_proj_biases = vs.get(embed_dim * 3, "in_proj_bias")?.chunk(3, 0)?;
+        let (q_b, k_b, v_b) = (&in_proj_biases[0], &in_proj_biases[1], &in_proj_biases[2]);
+
+        let q_proj = Linear::new(q_w.clone(), Some(q_b.clone()));
+        let k_proj = Linear::new(k_w.clone(), Some(k_b.clone()));
+        let v_proj = Linear::new(v_w.clone(), Some(v_b.clone()));
+        let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
+        let head_dim = embed_dim / num_attention_heads;
+        let scale = (head_dim as f64).powf(-0.5);
+
+        Ok(Attention {
+            k_proj,
+            v_proj,
+            q_proj,
+            out_proj,
+            head_dim,
+            scale,
+            num_attention_heads,
+        })
+    }
+
+    fn shape_multihead(&self, xs: &Tensor, bsz: usize, seq_len: usize) -> Result<Tensor> {
+        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
+            .transpose(1, 2)?
+            .contiguous()?
+            .to_dtype(DType::F32)
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let in_dtype = xs.dtype();
+        let (bsz, seq_len, embed_dim) = xs.dims3()?;
+
+        let q = self.shape_multihead(&self.q_proj.forward(xs)?, bsz, seq_len)?;
+        let k = self.shape_multihead(&self.k_proj.forward(xs)?, bsz, seq_len)?;
+        let v = self.shape_multihead(&self.v_proj.forward(xs)?, bsz, seq_len)?;
+        let q = (q * self.scale)?;
+
+        let attn_weights = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
+
+        let attn_weights = softmax_last_dim(&attn_weights)?;
+
+        let attn_output = attn_weights.matmul(&v)?.to_dtype(in_dtype)?;
+        let attn_output = attn_output
+            .transpose(1, 2)?
+            .contiguous()?
+            .reshape((bsz, seq_len, embed_dim))?;
+        let out = self.out_proj.forward(&attn_output)?;
+        Ok(out)
+    }
+}
+
+#[derive(Clone, Debug)]
+struct Mlp {
+    fc1: Linear,
+    fc2: Linear,
+}
+
+impl Mlp {
+    fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
+        let fc1 = linear(c.embed_dim, c.intermediate_size, vs.pp("c_fc"))?;
+        let fc2 = linear(c.intermediate_size, c.embed_dim, vs.pp("c_proj"))?;
+
+        Ok(Mlp { fc1, fc2 })
+    }
+}
+
+impl Mlp {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.fc1.forward(xs)?;
+        self.fc2.forward(&xs.gelu_erf()?)
+    }
+}
+
+#[derive(Clone, Debug)]
+struct EncoderLayer {
+    self_attn: Attention,
+    layer_norm1: LayerNorm,
+    mlp: Mlp,
+    layer_norm2: LayerNorm,
+}
+
+impl EncoderLayer {
+    fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
+        let self_attn = Attention::new(vs.pp("attn"), c)?;
+        let layer_norm1 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_1"))?;
+        let mlp = Mlp::new(vs.pp("mlp"), c)?;
+        let layer_norm2 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_2"))?;
+
+        Ok(EncoderLayer {
+            self_attn,
+            layer_norm1,
+            mlp,
+            layer_norm2,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let residual = xs;
+        let xs = self.layer_norm1.forward(xs)?;
+        let xs = self.self_attn.forward(&xs)?;
+        let xs = (xs + residual)?;
+
+        let residual = &xs;
+        let xs = self.layer_norm2.forward(&xs)?;
+        let xs = self.mlp.forward(&xs)?;
+        let out = (xs + residual)?;
+        Ok(out)
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct Encoder {
+    layers: Vec<EncoderLayer>,
+}
+
+impl Encoder {
+    pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
+        let vs = vs.pp("resblocks");
+        let mut layers: Vec<EncoderLayer> = Vec::new();
+        for index in 0..c.num_hidden_layers {
+            let layer = EncoderLayer::new(vs.pp(index.to_string()), c)?;
+            layers.push(layer)
+        }
+        Ok(Encoder { layers })
+    }
+
+    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            xs = layer.forward(&xs)?;
+        }
+        Ok(xs)
+    }
+}
+
+/// A text transformer as used in CLIP variants.
+#[derive(Clone, Debug)]
+pub struct OpenClipTextTransformer {
+    embeddings: TextEmbeddings,
+    encoder: Encoder,
+    final_layer_norm: LayerNorm,
+}
+
+impl OpenClipTextTransformer {
+    pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
+        let embeddings = TextEmbeddings::new(vs.clone(), c)?;
+        let final_layer_norm = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_final"))?;
+        let encoder = Encoder::new(vs.pp("transformer"), c)?;
+        Ok(OpenClipTextTransformer {
+            embeddings,
+            encoder,
+            final_layer_norm,
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let input_ids = self.embeddings.forward(input_ids)?;
+        let input_ids = self.encoder.forward(&input_ids)?;
+        self.final_layer_norm.forward(&input_ids)
+    }
+}
+
+impl Module for OpenClipTextTransformer {
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let output = self.forward(input_ids)?;
+        let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
+
+        let mut indices = Vec::new();
+        for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
+            let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
+            indices.push(index);
+        }
+        Tensor::cat(&indices, 0)
+    }
+}
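
For quick reference, the pieces added by this patch compose as in the sketch below. It mirrors `main.rs` above in condensed form (same model id, weight and tokenizer file names, zero padding, and mean 0 / std 1 normalization); the image path and prompts are illustrative, error handling is trimmed, and the `hf-hub`, `tokenizers` and `anyhow` crates are assumed, exactly as in the example.

```
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::mobileclip::{MobileClipConfig, MobileClipModel};
use tokenizers::Tokenizer;

fn run() -> anyhow::Result<()> {
    let device = Device::Cpu;
    let config = MobileClipConfig::s1();

    // Fetch the S1 weights and tokenizer from the Hub, as main.rs does.
    let api = hf_hub::api::sync::Api::new()?.model("apple/MobileCLIP-S1-OpenCLIP".to_string());
    let weights = api.get("open_clip_model.safetensors")?;
    let tokenizer = Tokenizer::from_file(api.get("tokenizer.json")?).map_err(anyhow::Error::msg)?;

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)? };
    let model = MobileClipModel::new(vb, &config)?;

    // One image, two prompts; the image is normalized with mean 0 / std 1 as in the example.
    let image = candle_examples::imagenet::load_image_with_std_mean(
        "candle-examples/examples/yolo-v8/assets/bike.jpg",
        config.image_size,
        &[0.0, 0.0, 0.0],
        &[1.0, 1.0, 1.0],
    )?
    .unsqueeze(0)?
    .to_device(&device)?;

    let prompts = ["a cycling race", "a photo of two cats"];
    let mut tokens: Vec<Vec<u32>> = Vec::new();
    for prompt in prompts {
        let encoding = tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?;
        tokens.push(encoding.get_ids().to_vec());
    }
    // Pad every sequence to the longest one with id 0, as tokenize_sequences does.
    let max_len = tokens.iter().map(|t| t.len()).max().unwrap_or(0);
    for t in tokens.iter_mut() {
        t.resize(max_len, 0);
    }
    let input_ids = Tensor::new(tokens, &device)?;

    let (_logits_per_text, logits_per_image) = model.forward(&image, &input_ids)?;
    let probs = softmax(&logits_per_image, 1)?.flatten_all()?.to_vec1::<f32>()?;
    println!("{prompts:?} -> {probs:?}");
    Ok(())
}
```

The `(logits_per_text, logits_per_image)` pair returned by `forward` is already scaled by `logit_scale.exp()`, so a plain softmax over each image row yields the per-prompt probabilities that the example prints.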