From b0340d72ec9dd8f3bb1778e5a7d73111e67a4393 Mon Sep 17 00:00:00 2001
From: Tigran Zhampeissov <81493298+Tigranchick@users.noreply.github.com>
Date: Thu, 28 Mar 2024 17:44:12 +0500
Subject: [PATCH] CLIP model implementation with example (#1950)

* CLIP model implementation with example

* CLIP Implementation fixes, batch images

* CLIP model remove images from git

* CLIP model remove unnecessary use of batch_indices
---
 candle-examples/examples/clip/README.md     |  46 +++
 candle-examples/examples/clip/main.rs       | 202 ++++++++++
 candle-transformers/src/models/clip/mod.rs  | 167 ++++++++
 .../src/models/clip/text_model.rs           | 355 ++++++++++++++++++
 .../src/models/clip/vision_model.rs         | 171 +++++++++
 candle-transformers/src/models/mod.rs       |   1 +
 6 files changed, 942 insertions(+)
 create mode 100644 candle-examples/examples/clip/README.md
 create mode 100644 candle-examples/examples/clip/main.rs
 create mode 100644 candle-transformers/src/models/clip/mod.rs
 create mode 100644 candle-transformers/src/models/clip/text_model.rs
 create mode 100644 candle-transformers/src/models/clip/vision_model.rs

diff --git a/candle-examples/examples/clip/README.md b/candle-examples/examples/clip/README.md
new file mode 100644
index 00000000..f0ee3b2c
--- /dev/null
+++ b/candle-examples/examples/clip/README.md
@@ -0,0 +1,46 @@
+Contrastive Language-Image Pre-Training
+
+Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+pairs of images and their related texts.
+
+https://github.com/openai/CLIP
+
+https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
+
+## Running the example on CPU
+
+```
+$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
+
+
+Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
+
+INFO clip: Probability: 0.0000% Text: a cycling race
+INFO clip: Probability: 0.0000% Text: a photo of two cats
+INFO clip: Probability: 100.0000% Text: a robot holding a candle
+
+Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
+
+INFO clip: Probability: 99.9999% Text: a cycling race
+INFO clip: Probability: 0.0001% Text: a photo of two cats
+INFO clip: Probability: 0.0000% Text: a robot holding a candle
+```
+
+## Running the example with the Metal feature (macOS)
+
+```
+$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --sequences "a cycling race","a photo of two cats","a robot holding a candle"
+
+
+Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
+
+INFO clip: Probability: 0.0000% Text: a cycling race
+INFO clip: Probability: 0.0000% Text: a photo of two cats
+INFO clip: Probability: 100.0000% Text: a robot holding a candle
+
+Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
+
+INFO clip: Probability: 99.9999% Text: a cycling race
+INFO clip: Probability: 0.0001% Text: a photo of two cats
+INFO clip: Probability: 0.0000% Text: a robot holding a candle
+```
diff --git a/candle-examples/examples/clip/main.rs b/candle-examples/examples/clip/main.rs
new file mode 100644
index 00000000..f301d211
--- /dev/null
+++ b/candle-examples/examples/clip/main.rs
@@ -0,0 +1,202 @@
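+// This example downloads the openai/clip-vit-base-patch32 weights and tokenizer from the
+// Hugging Face Hub (or takes them from --model / --tokenizer), batches the --images into a
+// single tensor, tokenizes the --sequences, runs ClipModel::forward, and reports a softmax
+// over the per-image similarity logits.
+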
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Error as E;
+use clap::Parser;
+
+use candle::{DType, Device, Tensor};
+use candle_nn::{ops::softmax, VarBuilder};
+use candle_transformers::models::clip;
+
+use tokenizers::Tokenizer;
+use tracing::info;
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    #[arg(long)]
+    tokenizer: Option<String>,
+
+    #[arg(long, use_value_delimiter = true)]
+    images: Option<Vec<String>>,
+
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(long, use_value_delimiter = true)]
+    sequences: Option<Vec<String>>,
+}
+
+fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
+    let img = image::io::Reader::open(path)?.decode()?;
+    let (height, width) = (image_size, image_size);
+    let img = img.resize_to_fill(
+        width as u32,
+        height as u32,
+        image::imageops::FilterType::Triangle,
+    );
+
+    let img = img.to_rgb8();
+
+    let img = img.into_raw();
+    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
+        .permute((2, 0, 1))?
+        .to_dtype(DType::F32)?
+        .affine(2. / 255., -1.)?;
+    // .unsqueeze(0)?;
+    Ok(img)
+}
+
+fn load_images<T: AsRef<std::path::Path>>(
+    paths: &Vec<T>,
+    image_size: usize,
+) -> anyhow::Result<Tensor> {
+    let mut images = vec![];
+
+    for path in paths {
+        let tensor = load_image(path, image_size)?;
+        images.push(tensor);
+    }
+
+    let images = Tensor::stack(&images, 0)?;
+
+    Ok(images)
+}
+
+pub fn main() -> anyhow::Result<()> {
+    // std::env::set_var("RUST_BACKTRACE", "full");
+
+    let args = Args::parse();
+
+    tracing_subscriber::fmt::init();
+
+    let model_file = match args.model {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+
+            let api = api.repo(hf_hub::Repo::with_revision(
+                "openai/clip-vit-base-patch32".to_string(),
+                hf_hub::RepoType::Model,
+                "refs/pr/15".to_string(),
+            ));
+
+            api.get("model.safetensors")?
+        }
+        Some(model) => model.into(),
+    };
+
+    let tokenizer = get_tokenizer(args.tokenizer)?;
+
+    let config = clip::ClipConfig::vit_base_patch32();
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let vec_imgs = match args.images {
+        Some(imgs) => imgs,
+        None => vec![
+            "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
+            "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
+        ],
+    };
+
+    // let image = load_image(args.image, config.image_size)?.to_device(&device)?;
+    let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
+
+    let vb = unsafe {
+        VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)?
+    };
+
+    let model = clip::ClipModel::new(vb, &config)?;
+
+    let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
+
+    let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
+
+    let softmax_image = softmax(&logits_per_image, 1)?;
+
+    let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
+
+    info!("softmax_image_vec: {:?}", softmax_image_vec);
+
+    let probability_vec = softmax_image_vec
+        .iter()
+        .map(|v| v * 100.0)
+        .collect::<Vec<f32>>();
+
+    let probability_per_image = probability_vec.len() / vec_imgs.len();
+
+    for (i, img) in vec_imgs.iter().enumerate() {
+        let start = i * probability_per_image;
+        let end = start + probability_per_image;
+        let prob = &probability_vec[start..end];
+        info!("\n\nResults for image: {}\n", img);
+
+        for (i, p) in prob.iter().enumerate() {
+            info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
+        }
+    }
+
+    Ok(())
+}
+
+pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
+    let tokenizer = match tokenizer {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.repo(hf_hub::Repo::with_revision(
+                "openai/clip-vit-base-patch32".to_string(),
+                hf_hub::RepoType::Model,
+                "refs/pr/15".to_string(),
+            ));
+            api.get("tokenizer.json")?
+        }
+        Some(file) => file.into(),
+    };
+
+    Tokenizer::from_file(tokenizer).map_err(E::msg)
+}
+
+pub fn tokenize_sequences(
+    sequences: Option<Vec<String>>,
+    tokenizer: &Tokenizer,
+    device: &Device,
+) -> anyhow::Result<(Tensor, Vec<String>)> {
+    let pad_id = *tokenizer
+        .get_vocab(true)
+        .get("<|endoftext|>")
+        .ok_or(E::msg("No pad token"))?;
+
+    let vec_seq = match sequences {
+        Some(seq) => seq,
+        None => vec![
+            "a cycling race".to_string(),
+            "a photo of two cats".to_string(),
+            "a robot holding a candle".to_string(),
+        ],
+    };
+
+    let mut tokens = vec![];
+
+    for seq in vec_seq.clone() {
+        let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
+        tokens.push(encoding.get_ids().to_vec());
+    }
+
+    let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
+
+    // Pad the sequences to have the same length
+    for token_vec in tokens.iter_mut() {
+        let len_diff = max_len - token_vec.len();
+        if len_diff > 0 {
+            token_vec.extend(vec![pad_id; len_diff]);
+        }
+    }
+
+    let input_ids = Tensor::new(tokens, device)?;
+
+    Ok((input_ids, vec_seq))
+}
diff --git a/candle-transformers/src/models/clip/mod.rs b/candle-transformers/src/models/clip/mod.rs
new file mode 100644
index 00000000..02df782b
--- /dev/null
+++ b/candle-transformers/src/models/clip/mod.rs
@@ -0,0 +1,167 @@
+//! Contrastive Language-Image Pre-Training
+//!
+//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+//! pairs of images with related texts.
+//!
+//! https://github.com/openai/CLIP
+//!
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip +use self::{ + text_model::{Activation, ClipTextTransformer}, + vision_model::ClipVisionTransformer, +}; +use candle::{Result, Tensor, D}; +use candle_nn::Module; + +use tracing::warn; + +pub mod text_model; +pub mod vision_model; + +pub struct ClipModel { + text_model: ClipTextTransformer, + vision_model: ClipVisionTransformer, + visual_projection: candle_nn::Linear, + text_projection: candle_nn::Linear, + logit_scale: Tensor, +} + +pub enum EncoderConfig { + Text(text_model::ClipTextConfig), + Vision(vision_model::ClipVisionConfig), +} + +impl EncoderConfig { + pub fn embed_dim(&self) -> usize { + match self { + Self::Text(c) => c.embed_dim, + Self::Vision(c) => c.embed_dim, + } + } + + pub fn num_attention_heads(&self) -> usize { + match self { + Self::Text(c) => c.num_attention_heads, + Self::Vision(c) => c.num_attention_heads, + } + } + + pub fn intermediate_size(&self) -> usize { + match self { + Self::Text(c) => c.intermediate_size, + Self::Vision(c) => c.intermediate_size, + } + } + + pub fn num_hidden_layers(&self) -> usize { + match self { + Self::Text(c) => c.num_hidden_layers, + Self::Vision(c) => c.num_hidden_layers, + } + } + + pub fn activation(&self) -> Activation { + match self { + Self::Text(_c) => Activation::QuickGelu, + Self::Vision(c) => c.activation, + } + } +} + +pub struct ClipConfig { + pub text_config: text_model::ClipTextConfig, + pub vision_config: vision_model::ClipVisionConfig, + pub logit_scale_init_value: f32, + pub image_size: usize, +} + +impl ClipConfig { + // base image size is 224, model size is 600Mb + pub fn vit_base_patch32() -> Self { + let text_config = text_model::ClipTextConfig::vit_base_patch32(); + let vision_config = vision_model::ClipVisionConfig::vit_base_patch32(); + + Self { + text_config, + vision_config, + logit_scale_init_value: 2.6592, + image_size: 224, + } + } +} + +impl ClipModel { + pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result { + let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?; + + let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?; + + let visual_projection = candle_nn::linear_no_bias( + c.vision_config.embed_dim, + c.vision_config.projection_dim, + vs.pp("visual_projection"), + )?; + + let text_projection = candle_nn::linear_no_bias( + c.text_config.embed_dim, + c.text_config.projection_dim, + vs.pp("text_projection"), + )?; + + // originally nn.Parameter + let logit_scale = if vs.contains_tensor("logit_scale") { + vs.get(&[], "logit_scale")? + } else { + warn!("Creating logit_scale tensor, results may vary."); + Tensor::new(&[c.logit_scale_init_value], vs.device())? 
+ }; + + Ok(Self { + text_model, + vision_model, + visual_projection, + text_projection, + logit_scale, + }) + } + + pub fn get_text_features(&self, input_ids: &Tensor) -> Result { + let text_outputs = self.text_model.forward(input_ids)?; + + let text_features = self.text_projection.forward(&text_outputs)?; + + Ok(text_features) + } + + pub fn get_image_features(&self, pixel_values: &Tensor) -> Result { + let image_features = self.vision_model.forward(pixel_values)?; + + let image_features = self.visual_projection.forward(&image_features)?; + + Ok(image_features) + } + + pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> { + let image_features = self.get_image_features(pixel_values)?; + + let text_features = self.get_text_features(input_ids)?; + + let image_features_normalized = div_l2_norm(&image_features)?; + + let text_features_normalized = div_l2_norm(&text_features)?; + + let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?; + + let logit_scale = &self.logit_scale.exp()?; + + let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?; + + let logits_per_image = logits_per_text.t()?; + + Ok((logits_per_text, logits_per_image)) + } +} + +pub fn div_l2_norm(v: &Tensor) -> Result { + let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; + v.broadcast_div(&l2_norm) +} diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs new file mode 100644 index 00000000..852d3e24 --- /dev/null +++ b/candle-transformers/src/models/clip/text_model.rs @@ -0,0 +1,355 @@ +//! Contrastive Language-Image Pre-Training +//! +//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/openai/CLIP +//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip + +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn as nn; +use candle_nn::Module; + +use super::EncoderConfig; + +#[derive(Debug, Clone, Copy)] +pub enum Activation { + QuickGelu, +} + +impl Module for Activation { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?, + } + } +} + +#[derive(Debug, Clone)] +pub struct ClipTextConfig { + pub vocab_size: usize, + pub embed_dim: usize, + pub activation: Activation, + pub intermediate_size: usize, + pub max_position_embeddings: usize, + pub pad_with: Option, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + #[allow(dead_code)] + pub projection_dim: usize, +} + +impl ClipTextConfig { + // The config details can be found in the "text_config" section of this json file: + // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + pub fn vit_base_patch32() -> Self { + Self { + vocab_size: 49408, + embed_dim: 512, + intermediate_size: 2048, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 12, + num_attention_heads: 8, + projection_dim: 512, + activation: Activation::QuickGelu, + } + } +} + +// ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model. 
+// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142 +#[derive(Debug)] +struct ClipTextEmbeddings { + token_embedding: candle_nn::Embedding, + position_embedding: candle_nn::Embedding, + position_ids: Tensor, +} + +impl ClipTextEmbeddings { + fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + let token_embedding = + candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?; + + let position_embedding: nn::Embedding = candle_nn::embedding( + c.max_position_embeddings, + c.embed_dim, + vs.pp("position_embedding"), + )?; + + let position_ids = + Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?; + + Ok(ClipTextEmbeddings { + token_embedding, + position_embedding, + position_ids, + }) + } +} + +impl Module for ClipTextEmbeddings { + fn forward(&self, input_ids: &Tensor) -> Result { + let seq_length = input_ids.dim(D::Minus1)?; + + let inputs_embeds = &self.token_embedding.forward(input_ids)?; + + let postion_ids = &self.position_ids.narrow(1, 0, seq_length)?; + + let position_embedding = &self.position_embedding.forward(&postion_ids)?; + + let inputs_embeds = inputs_embeds.broadcast_add(&position_embedding)?; + + Ok(inputs_embeds) + } +} + +#[derive(Debug)] +struct ClipAttention { + k_proj: candle_nn::Linear, + v_proj: candle_nn::Linear, + q_proj: candle_nn::Linear, + out_proj: candle_nn::Linear, + head_dim: usize, + scale: f64, + num_attention_heads: usize, +} + +impl ClipAttention { + fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result { + let embed_dim = c.embed_dim(); + let num_attention_heads = c.num_attention_heads(); + let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?; + let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?; + let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?; + let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?; + let head_dim = embed_dim / num_attention_heads; + let scale = (head_dim as f64).powf(-0.5); + + Ok(ClipAttention { + k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + }) + } + + fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result { + xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous() + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let in_dtype = xs.dtype(); + let (bsz, seq_len, embed_dim) = xs.dims3()?; + + let query_states = (self.q_proj.forward(xs)? * self.scale)?; + let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim); + let query_states = self + .shape(&query_states, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let key_states = self + .shape(&self.k_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let value_states = self + .shape(&self.v_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? 
+ .to_dtype(DType::F32)?; + let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; + + let src_len = key_states.dim(1)?; + + let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask { + let attn_reshape = + attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?; + + let attn_weights = attn_reshape.broadcast_add(causal_attention_mask)?; + + let attn_weights = + attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?; + + attn_weights + } else { + attn_weights + }; + + let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + + let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?; + let attn_output = attn_output + .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))? + .transpose(1, 2)? + .reshape((bsz, seq_len, embed_dim))?; + self.out_proj.forward(&attn_output) + } +} + +#[derive(Debug)] +struct ClipMlp { + fc1: candle_nn::Linear, + fc2: candle_nn::Linear, + activation: Activation, +} + +impl ClipMlp { + fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result { + let fc1 = candle_nn::linear(c.embed_dim(), c.intermediate_size(), vs.pp("fc1"))?; + let fc2 = candle_nn::linear(c.intermediate_size(), c.embed_dim(), vs.pp("fc2"))?; + + Ok(ClipMlp { + fc1, + fc2, + activation: c.activation(), + }) + } +} + +impl ClipMlp { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.fc1.forward(xs)?; + self.fc2.forward(&self.activation.forward(&xs)?) + } +} + +#[derive(Debug)] +struct ClipEncoderLayer { + self_attn: ClipAttention, + layer_norm1: candle_nn::LayerNorm, + mlp: ClipMlp, + layer_norm2: candle_nn::LayerNorm, +} + +impl ClipEncoderLayer { + fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result { + let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?; + let layer_norm1 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm1"))?; + let mlp = ClipMlp::new(vs.pp("mlp"), c)?; + let layer_norm2 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm2"))?; + + Ok(ClipEncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let residual = xs; + let xs = self.layer_norm1.forward(xs)?; + let xs = self.self_attn.forward(&xs, causal_attention_mask)?; + let xs = (xs + residual)?; + + let residual = &xs; + let xs = self.layer_norm2.forward(&xs)?; + let xs = self.mlp.forward(&xs)?; + xs + residual + } +} + +#[derive(Debug)] +pub struct ClipEncoder { + layers: Vec, +} + +impl ClipEncoder { + pub fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result { + let vs = vs.pp("layers"); + let mut layers: Vec = Vec::new(); + for index in 0..c.num_hidden_layers() { + let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?; + layers.push(layer) + } + Ok(ClipEncoder { layers }) + } + + pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let mut xs = xs.clone(); + + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)?; + } + Ok(xs) + } +} + +/// A CLIP transformer based model. 
+#[derive(Debug)] +pub struct ClipTextTransformer { + embeddings: ClipTextEmbeddings, + encoder: ClipEncoder, + final_layer_norm: candle_nn::LayerNorm, +} + +impl ClipTextTransformer { + pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?; + let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?; + let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?; + + Ok(ClipTextTransformer { + embeddings, + encoder, + final_layer_norm, + }) + } + + // TODO: rewrrite to newer version + fn build_causal_attention_mask( + bsz: usize, + seq_len: usize, + mask_after: usize, + device: &Device, + ) -> Result { + let mask: Vec<_> = (0..seq_len) + .flat_map(|i| { + (0..seq_len).map(move |j| { + if j > i || j > mask_after { + f32::MIN + } else { + 0. + } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?; + mask.broadcast_as((bsz, 1, seq_len, seq_len)) + } + + pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result { + let (bsz, seq_len) = input_ids.dims2()?; + let input_ids = self.embeddings.forward(input_ids)?; + + let causal_attention_mask = + Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?; + let input_ids = self + .encoder + .forward(&input_ids, Some(&causal_attention_mask))?; + self.final_layer_norm.forward(&input_ids) + } +} + +impl Module for ClipTextTransformer { + fn forward(&self, input_ids: &Tensor) -> Result { + let output = self.forward_with_mask(input_ids, usize::MAX)?; + + let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?; + + let mut indices: Vec = Vec::new(); + + for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::()?.iter().enumerate() { + let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?; + indices.push(index); + } + + let pooled_output = Tensor::cat(&indices, 0)?; + + Ok(pooled_output) + } +} diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs new file mode 100644 index 00000000..af9af7ae --- /dev/null +++ b/candle-transformers/src/models/clip/vision_model.rs @@ -0,0 +1,171 @@ +//! Contrastive Language-Image Pre-Training +//! +//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/openai/CLIP +//! 
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip + +use candle::{IndexOp, Result, Shape, Tensor, D}; +use candle_nn as nn; +use candle_nn::Module; +use nn::Conv2dConfig; +use tracing::warn; + +use super::{ + text_model::{Activation, ClipEncoder}, + EncoderConfig, +}; + +#[derive(Debug, Clone)] +pub struct ClipVisionConfig { + pub embed_dim: usize, + pub activation: Activation, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + #[allow(dead_code)] + pub projection_dim: usize, + pub num_channels: usize, + pub image_size: usize, + pub patch_size: usize, +} + +impl ClipVisionConfig { + // The config details can be found in the "vision_config" section of this json file: + // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + pub fn vit_base_patch32() -> Self { + Self { + embed_dim: 768, + activation: Activation::QuickGelu, + intermediate_size: 3072, + num_hidden_layers: 12, + num_attention_heads: 12, + projection_dim: 512, + num_channels: 3, + image_size: 224, + patch_size: 32, + } + } +} + +// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112 +#[derive(Debug)] +struct ClipVisionEmbeddings { + patch_embedding: candle_nn::Conv2d, + position_ids: Tensor, + class_embedding: Tensor, + position_embedding: candle_nn::Embedding, +} + +impl ClipVisionEmbeddings { + fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result { + // originally nn.Parameter + let class_embedding = if vs.contains_tensor("class_embedding") { + vs.get(c.embed_dim, "class_embedding")? + } else { + warn!("class_embedding not found in the. Initializing a new one."); + Tensor::randn(0.0 as f32, 1.0 as f32, &[c.embed_dim], vs.device())? 
+ }; + + let num_patches = (c.image_size / c.patch_size).pow(2); + + let num_positions = num_patches + 1; + + let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?; + + let conv2dconfig = Conv2dConfig { + stride: c.patch_size, + ..Default::default() + }; + let position_embedding = + candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?; + + let patch_embedding = candle_nn::conv2d_no_bias( + c.num_channels, + c.embed_dim, + c.patch_size, + conv2dconfig, + vs.pp("patch_embedding"), + )?; + + Ok(Self { + patch_embedding, + position_ids, + class_embedding, + position_embedding, + }) + } +} + +impl Module for ClipVisionEmbeddings { + fn forward(&self, pixel_values: &Tensor) -> Result { + let batch_size = pixel_values.shape().dims(); + + let patch_embeds = self.patch_embedding.forward(&pixel_values)?; + + let patch_embeds = patch_embeds.flatten_from(2)?; + + let patch_embeds = patch_embeds.transpose(1, 2)?; + + let class_embedding = self.class_embedding.clone(); + + let shape = Shape::from(vec![batch_size[0], 1, class_embedding.dim(D::Minus1)?]); + + let class_embeds = class_embedding.expand(shape)?; + + let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?; + + let position_embedding = self.position_embedding.forward(&self.position_ids)?; + + let embeddings = embeddings.broadcast_add(&position_embedding)?; + + Ok(embeddings) + } +} + +// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743 +#[derive(Debug)] +pub struct ClipVisionTransformer { + embeddings: ClipVisionEmbeddings, + encoder: ClipEncoder, + pre_layer_norm: candle_nn::LayerNorm, + final_layer_norm: candle_nn::LayerNorm, +} + +impl ClipVisionTransformer { + pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result { + let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?; + + let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?; + + let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?; + + let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?; + + Ok(Self { + embeddings, + encoder, + final_layer_norm, + pre_layer_norm, + }) + } +} + +impl Module for ClipVisionTransformer { + fn forward(&self, pixel_values: &Tensor) -> Result { + let hidden_states = self.embeddings.forward(pixel_values)?; + + let hidden_states = self.pre_layer_norm.forward(&hidden_states)?; + + let encoder_outputs = self.encoder.forward(&hidden_states, None)?; + + // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787 + // pooled_output = encoder_outputs[:, 0, :] + let pooled_output = encoder_outputs.i((.., 0, ..))?; + + let output = self.final_layer_norm.forward(&pooled_output)?; + + Ok(output) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 389d1a80..4267059c 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -12,6 +12,7 @@ pub mod efficientvit; pub mod encodec; pub mod falcon; pub mod gemma; +pub mod clip; pub mod jina_bert; pub mod llama; pub mod llama2_c;
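For reviewers, a minimal sketch of how the new `clip` API added by this patch can be driven directly, assuming `model.safetensors` and `tokenizer.json` are already on disk (the bundled example fetches them from the `openai/clip-vit-base-patch32` Hub repo); the zero-filled image tensor is only a shape placeholder for real preprocessed pixels:

```rust
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;
use tokenizers::Tokenizer;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    let config = clip::ClipConfig::vit_base_patch32();

    // Load the CLIP weights (local path assumed for this sketch).
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = clip::ClipModel::new(vb, &config)?;

    // Placeholder image batch of shape (batch, channels, image_size, image_size);
    // the example normalizes real pixels to [-1, 1] before calling forward.
    let images = Tensor::zeros(
        (1, 3, config.image_size, config.image_size),
        DType::F32,
        &device,
    )?;

    // Tokenize a single caption; the example pads a batch of captions to equal length.
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;
    let ids = tokenizer
        .encode("a robot holding a candle", true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();
    let input_ids = Tensor::new(vec![ids], &device)?;

    // forward returns (logits_per_text, logits_per_image); a softmax over the text axis
    // of logits_per_image turns the similarities into per-image probabilities.
    let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
    let probs = softmax(&logits_per_image, 1)?;
    println!("{probs}");
    Ok(())
}
```

The bundled `main.rs` wraps exactly this flow with CLI arguments, Hub downloads, and batched image loading.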