Mirror of https://github.com/huggingface/candle.git (synced 2025-06-22 12:28:06 +00:00)
CLIP model implementation with example (#1950)
* CLIP model implementation with example
* CLIP implementation fixes, batch images
* CLIP model: remove images from git
* CLIP model: remove unnecessary use of batch_indices
committed by GitHub
parent b3484e7a5e
commit b0340d72ec
candle-transformers/src/models/clip/mod.rs (new file, 167 additions)
@@ -0,0 +1,167 @@
//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
use self::{
    text_model::{Activation, ClipTextTransformer},
    vision_model::ClipVisionTransformer,
};
use candle::{Result, Tensor, D};
use candle_nn::Module;

use tracing::warn;

pub mod text_model;
pub mod vision_model;

pub struct ClipModel {
    text_model: ClipTextTransformer,
    vision_model: ClipVisionTransformer,
    visual_projection: candle_nn::Linear,
    text_projection: candle_nn::Linear,
    logit_scale: Tensor,
}

pub enum EncoderConfig {
    Text(text_model::ClipTextConfig),
    Vision(vision_model::ClipVisionConfig),
}

impl EncoderConfig {
    pub fn embed_dim(&self) -> usize {
        match self {
            Self::Text(c) => c.embed_dim,
            Self::Vision(c) => c.embed_dim,
        }
    }

    pub fn num_attention_heads(&self) -> usize {
        match self {
            Self::Text(c) => c.num_attention_heads,
            Self::Vision(c) => c.num_attention_heads,
        }
    }

    pub fn intermediate_size(&self) -> usize {
        match self {
            Self::Text(c) => c.intermediate_size,
            Self::Vision(c) => c.intermediate_size,
        }
    }

    pub fn num_hidden_layers(&self) -> usize {
        match self {
            Self::Text(c) => c.num_hidden_layers,
            Self::Vision(c) => c.num_hidden_layers,
        }
    }

    pub fn activation(&self) -> Activation {
        match self {
            Self::Text(_c) => Activation::QuickGelu,
            Self::Vision(c) => c.activation,
        }
    }
}

pub struct ClipConfig {
    pub text_config: text_model::ClipTextConfig,
    pub vision_config: vision_model::ClipVisionConfig,
    pub logit_scale_init_value: f32,
    pub image_size: usize,
}

impl ClipConfig {
    // base image size is 224, model size is 600Mb
    pub fn vit_base_patch32() -> Self {
        let text_config = text_model::ClipTextConfig::vit_base_patch32();
        let vision_config = vision_model::ClipVisionConfig::vit_base_patch32();

        Self {
            text_config,
            vision_config,
            logit_scale_init_value: 2.6592,
            image_size: 224,
        }
    }
}

impl ClipModel {
    pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
        let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;

        let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;

        let visual_projection = candle_nn::linear_no_bias(
            c.vision_config.embed_dim,
            c.vision_config.projection_dim,
            vs.pp("visual_projection"),
        )?;

        let text_projection = candle_nn::linear_no_bias(
            c.text_config.embed_dim,
            c.text_config.projection_dim,
            vs.pp("text_projection"),
        )?;

        // originally nn.Parameter
        let logit_scale = if vs.contains_tensor("logit_scale") {
            vs.get(&[], "logit_scale")?
        } else {
            warn!("Creating logit_scale tensor, results may vary.");
            Tensor::new(&[c.logit_scale_init_value], vs.device())?
        };

        Ok(Self {
            text_model,
            vision_model,
            visual_projection,
            text_projection,
            logit_scale,
        })
    }

    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
        let text_outputs = self.text_model.forward(input_ids)?;

        let text_features = self.text_projection.forward(&text_outputs)?;

        Ok(text_features)
    }

    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
        let image_features = self.vision_model.forward(pixel_values)?;

        let image_features = self.visual_projection.forward(&image_features)?;

        Ok(image_features)
    }

    pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
        let image_features = self.get_image_features(pixel_values)?;

        let text_features = self.get_text_features(input_ids)?;

        let image_features_normalized = div_l2_norm(&image_features)?;

        let text_features_normalized = div_l2_norm(&text_features)?;

        let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;

        let logit_scale = &self.logit_scale.exp()?;

        let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;

        let logits_per_image = logits_per_text.t()?;

        Ok((logits_per_text, logits_per_image))
    }
}

pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
    let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
    v.broadcast_div(&l2_norm)
}
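As a quick illustration of the helper above, here is a small hypothetical snippet (assuming the module is exposed as candle_transformers::models::clip, as its path under candle-transformers/src/models suggests) showing that div_l2_norm rescales each row of a 2D tensor to unit L2 norm:

use candle::{Device, Result, Tensor};
use candle_transformers::models::clip::div_l2_norm;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Two embedding vectors with L2 norms 5 and 13.
    let v = Tensor::new(&[[3f32, 4.], [5., 12.]], &device)?;
    // Each row is divided by its own L2 norm, so the result is
    // [[0.6, 0.8], [5/13, 12/13]] and every row has unit length.
    let unit = div_l2_norm(&v)?;
    println!("{unit}");
    Ok(())
}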
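For context, a minimal end-to-end usage sketch of the new ClipModel (hypothetical, and not the candle-examples CLIP example added alongside this PR): it assumes ViT-B/32 CLIP weights are available locally as model.safetensors and uses dummy, all-zero inputs in place of a real tokenizer and image preprocessor.

use candle::{DType, Device, Result, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::clip::{ClipConfig, ClipModel};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let config = ClipConfig::vit_base_patch32();

    // Hypothetical local checkpoint; in practice this would be the
    // openai/clip-vit-base-patch32 safetensors file.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = ClipModel::new(vb, &config)?;

    // Dummy stand-ins for a real pipeline: one 224x224 RGB image and two
    // tokenized captions padded to CLIP's context length of 77.
    let pixel_values = Tensor::zeros(
        (1, 3, config.image_size, config.image_size),
        DType::F32,
        &device,
    )?;
    let input_ids = Tensor::zeros((2, 77), DType::U32, &device)?;

    // logits_per_image has shape (n_images, n_texts); a softmax over the
    // text axis gives, for each image, a distribution over the captions.
    let (_logits_per_text, logits_per_image) = model.forward(&pixel_values, &input_ids)?;
    let probs = candle_nn::ops::softmax(&logits_per_image, D::Minus1)?;
    println!("{probs}");
    Ok(())
}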