feat: intergrate chinese clip and add example (#2555)

* start to impl chinese clip

* impl vision model

* copy code from bert

* refactor use

* refactor use again

* fix text model

* refactor

* try to fix text model

* tuning

* tuning chinese clip

* delete useless code

* revert code

* Clippy fixes.

* Also apply cargo fmt.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
SethWen
2024-10-10 21:18:55 +08:00
committed by GitHub
parent 937e8eda74
commit 0d96ec31e8
5 changed files with 1358 additions and 0 deletions

View File

@ -0,0 +1,224 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, Device, Tensor};
use candle_nn as nn;
use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};
use clap::Parser;
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
tracing_subscriber::fmt::init();
let device = candle_examples::device(args.cpu)?;
let var = load_weights(args.model, &device)?;
let clip_model = ChineseClipModel::new(var, &ChineseClipConfig::clip_vit_base_patch16())?;
tracing::info!("Transformer loaded. ");
let (pixel_values, vec_imgs) = load_images(args.images, &device)?;
tracing::info!("Images loaded. ");
let tokenizer = load_tokenizer()?;
let (input_ids, type_ids, attention_mask, text_sequences) =
tokenize_sequences(args.sequences, &tokenizer, &device)?;
tracing::info!("Computing ... ");
let (_logits_per_text, logits_per_image) = clip_model.forward(
&pixel_values,
&input_ids,
Some(&type_ids),
Some(&attention_mask),
)?;
let softmax_image = nn::ops::softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
tracing::info!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
tracing::info!("Probability: {:.4}% Text: {} ", p, text_sequences[i]);
}
}
Ok(())
}
pub fn load_weights(model: Option<String>, device: &Device) -> anyhow::Result<nn::VarBuilder> {
let model_file = match model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = hf_hub::Repo::with_revision(
"OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
);
let api = api.repo(repo);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
Ok(unsafe { nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? })
}
pub fn load_tokenizer() -> anyhow::Result<Tokenizer> {
let tokenizer_file = {
let api = hf_hub::api::sync::Api::new()?;
let repo = hf_hub::Repo::with_revision(
"OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
);
let api = api.repo(repo);
api.get("tokenizer.json")?
};
Tokenizer::from_file(tokenizer_file).map_err(anyhow::Error::msg)
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Tensor, Tensor, Vec<String>)> {
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"自行车比赛".to_string(),
"两只猫咪".to_string(),
"拿着蜡烛的机器人".to_string(),
],
};
let mut input_ids = vec![];
let mut type_ids = vec![];
let mut attention_mask = vec![];
let mut max_len = 0;
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(anyhow::Error::msg)?;
input_ids.push(encoding.get_ids().to_vec());
type_ids.push(encoding.get_type_ids().to_vec());
attention_mask.push(encoding.get_attention_mask().to_vec());
if encoding.get_ids().len() > max_len {
max_len = encoding.get_ids().len();
}
}
let pad_id = *tokenizer
.get_vocab(true)
.get("[PAD]")
.ok_or(anyhow::Error::msg("No pad token"))?;
let input_ids: Vec<Vec<u32>> = input_ids
.iter_mut()
.map(|item| {
item.extend(vec![pad_id; max_len - item.len()]);
item.to_vec()
})
.collect();
let type_ids: Vec<Vec<u32>> = type_ids
.iter_mut()
.map(|item| {
item.extend(vec![0; max_len - item.len()]);
item.to_vec()
})
.collect();
let attention_mask: Vec<Vec<u32>> = attention_mask
.iter_mut()
.map(|item| {
item.extend(vec![0; max_len - item.len()]);
item.to_vec()
})
.collect();
let input_ids = Tensor::new(input_ids, device)?;
let type_ids = Tensor::new(type_ids, device)?;
let attention_mask = Tensor::new(attention_mask, device)?;
Ok((input_ids, type_ids, attention_mask, vec_seq))
}
pub fn load_images(
images: Option<Vec<String>>,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let vec_imgs = match images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let mut images = vec![];
for path in vec_imgs.iter() {
let tensor = load_image(path, 224, device)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?.to_device(device)?;
Ok((images, vec_imgs))
}
fn load_image<T: AsRef<std::path::Path>>(
path: T,
image_size: usize,
device: &Device,
) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8().into_raw();
let img = Tensor::from_vec(img, (height, width, 3), device)?.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], device)?.reshape((3, 1, 1))?;
let std =
Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], device)?.reshape((3, 1, 1))?;
let img = (img.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)?;
Ok(img)
}

View File

@ -0,0 +1,208 @@
//! Chinese contrastive Language-Image Pre-Training
//!
//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/OFA-Sys/Chinese-CLIP
//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
use text_model::ChineseClipTextTransformer;
use vision_model::ChineseClipVisionTransformer;
pub mod text_model;
pub mod vision_model;
#[derive(Debug, Clone, Copy)]
pub enum Activation {
QuickGelu,
Gelu,
GeluNew,
Relu,
}
impl From<String> for Activation {
fn from(value: String) -> Self {
match value.as_str() {
"quick_gelu" => Activation::QuickGelu,
"gelu" => Activation::Gelu,
"gelu_new" => Activation::GeluNew,
"relu" => Activation::Relu,
_ => panic!("Invalid activation function: {}", value),
}
}
}
impl Module for Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
Activation::Gelu => xs.gelu_erf(),
Activation::GeluNew => xs.gelu(),
Activation::Relu => xs.relu(),
}
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipConfig {
pub text_config: text_model::ChineseClipTextConfig,
pub vision_config: vision_model::ChineseClipVisionConfig,
pub projection_dim: usize,
pub logit_scale_init_value: f32,
pub image_size: usize,
}
impl ChineseClipConfig {
/// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
pub fn clip_vit_base_patch16() -> Self {
let text_config = text_model::ChineseClipTextConfig::clip_vit_base_patch16();
let vision_config = vision_model::ChineseClipVisionConfig::clip_vit_base_patch16();
Self {
text_config,
vision_config,
projection_dim: 512,
logit_scale_init_value: 2.6592,
image_size: 512,
}
}
}
#[derive(Clone, Debug)]
pub enum EncoderConfig {
Text(text_model::ChineseClipTextConfig),
Vision(vision_model::ChineseClipVisionConfig),
}
impl EncoderConfig {
pub fn embed_dim(&self) -> usize {
match self {
Self::Text(c) => c.hidden_size,
Self::Vision(c) => c.hidden_size,
}
}
pub fn num_attention_heads(&self) -> usize {
match self {
Self::Text(c) => c.num_attention_heads,
Self::Vision(c) => c.num_attention_heads,
}
}
pub fn intermediate_size(&self) -> usize {
match self {
Self::Text(c) => c.intermediate_size,
Self::Vision(c) => c.intermediate_size,
}
}
pub fn num_hidden_layers(&self) -> usize {
match self {
Self::Text(c) => c.num_hidden_layers,
Self::Vision(c) => c.num_hidden_layers,
}
}
pub fn activation(&self) -> Activation {
match self {
Self::Text(c) => c.hidden_act,
Self::Vision(c) => c.hidden_act,
}
}
pub fn layer_norm_eps(&self) -> f64 {
match self {
Self::Text(c) => c.layer_norm_eps,
Self::Vision(c) => c.layer_norm_eps,
}
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipModel {
text_model: ChineseClipTextTransformer,
vision_model: ChineseClipVisionTransformer,
visual_projection: nn::Linear,
text_projection: nn::Linear,
logit_scale: Tensor,
}
impl ChineseClipModel {
pub fn new(vs: nn::VarBuilder, c: &ChineseClipConfig) -> Result<Self> {
let text_model = ChineseClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
let vision_model =
ChineseClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
let vision_embed_dim = c.vision_config.hidden_size;
let vision_projection = nn::linear_no_bias(
vision_embed_dim,
c.projection_dim,
vs.pp("visual_projection"),
)?;
let text_embed_dim = c.text_config.hidden_size;
let text_projection =
nn::linear_no_bias(text_embed_dim, c.projection_dim, vs.pp("text_projection"))?;
let logit_scale = if vs.contains_tensor("logit_scale") {
vs.get(&[], "logit_scale")?
} else {
Tensor::new(&[c.logit_scale_init_value], vs.device())?
};
Ok(Self {
text_model,
vision_model,
visual_projection: vision_projection,
text_projection,
logit_scale,
})
}
pub fn get_text_features(
&self,
input_ids: &Tensor,
token_type_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let output = self
.text_model
.forward(input_ids, token_type_ids, attention_mask)?;
self.text_projection.forward(&output)
}
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
pixel_values
.apply(&self.vision_model)?
.apply(&self.visual_projection)
}
pub fn forward(
&self,
pixel_values: &Tensor,
input_ids: &Tensor,
token_type_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let image_features = self.get_image_features(pixel_values)?;
let text_features = self.get_text_features(input_ids, token_type_ids, attention_mask)?;
let image_features_normalized = div_l2_norm(&image_features)?;
let text_features_normalized = div_l2_norm(&text_features)?;
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
let logit_scale = self.logit_scale.exp()?;
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
let logits_per_image = logits_per_text.t()?;
Ok((logits_per_text, logits_per_image))
}
}
pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
v.broadcast_div(&l2_norm)
}

View File

@ -0,0 +1,540 @@
//! Chinese contrastive Language-Image Pre-Training
//!
//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/OFA-Sys/Chinese-CLIP
//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
use candle::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn as nn;
use super::Activation;
/// Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
/// positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
/// [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
/// For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
/// with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
#[derive(Clone, Debug)]
pub enum PositionEmbeddingType {
Absolute,
RelativeKey,
RelativeKeyQuery,
}
#[derive(Clone, Debug)]
pub struct ChineseClipTextConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub intermediate_size: usize,
pub hidden_act: Activation,
pub hidden_dropout_prob: f32,
pub attention_probs_dropout_prob: f64,
pub max_position_embeddings: usize,
pub type_vocab_size: usize,
pub initializer_range: f64,
pub initializer_factor: f64,
pub layer_norm_eps: f64,
pub pad_token_id: usize,
pub position_embedding_type: PositionEmbeddingType,
pub use_cache: bool,
}
impl Default for ChineseClipTextConfig {
fn default() -> Self {
Self {
vocab_size: 30522,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: Activation::Gelu,
hidden_dropout_prob: 0.1,
attention_probs_dropout_prob: 0.1,
max_position_embeddings: 512,
type_vocab_size: 2,
initializer_range: 0.02,
initializer_factor: 1.0,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: PositionEmbeddingType::Absolute,
use_cache: true,
}
}
}
impl ChineseClipTextConfig {
/// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
pub fn clip_vit_base_patch16() -> Self {
Self {
vocab_size: 21128,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: Activation::Gelu,
hidden_dropout_prob: 0.1,
attention_probs_dropout_prob: 0.1,
max_position_embeddings: 512,
type_vocab_size: 2,
initializer_range: 0.02,
initializer_factor: 1.0,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: PositionEmbeddingType::Absolute,
use_cache: true,
}
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipTextEmbeddings {
word_embeddings: nn::Embedding,
position_embeddings: nn::Embedding,
token_type_embeddings: nn::Embedding,
layer_norm: nn::LayerNorm,
dropout: nn::Dropout,
position_embedding_type: PositionEmbeddingType,
position_ids: Tensor,
token_type_ids: Tensor,
}
impl ChineseClipTextEmbeddings {
pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let word_embeddings = nn::embedding(
config.vocab_size,
config.hidden_size,
var.pp("word_embeddings"),
)?;
let position_embeddings = nn::embedding(
config.max_position_embeddings,
config.hidden_size,
var.pp("position_embeddings"),
)?;
let token_type_embeddings = nn::embedding(
config.type_vocab_size,
config.hidden_size,
var.pp("token_type_embeddings"),
)?;
let layer_norm = nn::layer_norm::<f64>(
config.hidden_size,
config.layer_norm_eps,
var.pp("LayerNorm"),
)?;
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
let position_ids =
Tensor::arange(0u32, config.max_position_embeddings as u32, var.device())?
.unsqueeze(0)?;
let token_type_ids = Tensor::zeros(position_ids.shape(), DType::I64, var.device())?;
Ok(Self {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
position_embedding_type: config.position_embedding_type.clone(),
position_ids,
token_type_ids,
})
}
fn forward(&self, xs: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor> {
let (_batch_size, seq_length) = xs.dims2()?;
let position_ids = (0..seq_length as u32).collect::<Vec<_>>();
let position_ids = self.position_ids.index_select(
&Tensor::new(&position_ids[..], self.position_ids.device())?,
1,
)?;
let word_embeddings = self.word_embeddings.forward(xs)?;
let token_type_ids = match token_type_ids {
Some(token_type_ids) => token_type_ids,
None => &self.token_type_ids.i((.., 0..seq_length))?,
};
let token_type_ids = token_type_ids.expand(xs.shape())?;
let token_type_embeddings = self.token_type_embeddings.forward(&token_type_ids)?;
let embeddings = (&word_embeddings + token_type_embeddings)?;
let embeddings = match self.position_embedding_type {
PositionEmbeddingType::Absolute => {
let position_embeddings = self.position_embeddings.forward(&position_ids)?;
let position_embeddings = position_embeddings.expand(embeddings.shape())?;
(embeddings + position_embeddings)?
}
_ => embeddings,
};
let embeddings = self.layer_norm.forward(&embeddings)?;
let embeddings = self.dropout.forward(&embeddings, false)?;
Ok(embeddings)
}
}
/// Copied from [`crate::models::bert::BertSelfOutput`] to [`ChineseClipTextSelfOutput`]
#[derive(Clone, Debug)]
struct ChineseClipTextSelfOutput {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: nn::Dropout,
span: tracing::Span,
}
impl ChineseClipTextSelfOutput {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
let layer_norm = nn::layer_norm(
config.hidden_size,
config.layer_norm_eps,
var.pp("LayerNorm"),
)?;
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
layer_norm,
dropout,
span: tracing::span!(tracing::Level::TRACE, "self-out"),
})
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states, false)?;
self.layer_norm.forward(&(hidden_states + input_tensor)?)
}
}
/// Copied from [`crate::models::bert::BertSelfAttention`] to [`ChineseClipTextSelfAttention`]
#[derive(Clone, Debug)]
struct ChineseClipTextSelfAttention {
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
dropout: nn::Dropout,
num_attention_heads: usize,
attention_head_size: usize,
span: tracing::Span,
span_softmax: tracing::Span,
}
impl ChineseClipTextSelfAttention {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
let hidden_size = config.hidden_size;
let query = nn::linear(hidden_size, all_head_size, var.pp("query"))?;
let value = nn::linear(hidden_size, all_head_size, var.pp("value"))?;
let key = nn::linear(hidden_size, all_head_size, var.pp("key"))?;
Ok(Self {
query,
key,
value,
dropout,
num_attention_heads: config.num_attention_heads,
attention_head_size,
span: tracing::span!(tracing::Level::TRACE, "self-attn"),
span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
})
}
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
let mut new_x_shape = xs.dims().to_vec();
new_x_shape.pop();
new_x_shape.push(self.num_attention_heads);
new_x_shape.push(self.attention_head_size);
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
xs.contiguous()
}
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let query_layer = self.query.forward(hidden_states)?;
let key_layer = self.key.forward(hidden_states)?;
let value_layer = self.value.forward(hidden_states)?;
let query_layer = self.transpose_for_scores(&query_layer)?;
let key_layer = self.transpose_for_scores(&key_layer)?;
let value_layer = self.transpose_for_scores(&value_layer)?;
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
let attention_scores = attention_scores.broadcast_add(attention_mask)?;
let attention_probs = {
let _enter_sm = self.span_softmax.enter();
nn::ops::softmax(&attention_scores, candle::D::Minus1)?
};
let attention_probs = self.dropout.forward(&attention_probs, false)?;
let context_layer = attention_probs.matmul(&value_layer)?;
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
Ok(context_layer)
}
}
/// Copied from [`crate::models::bert::BertAttention`] to [`ChineseClipTextAttention`]
#[derive(Clone, Debug)]
struct ChineseClipTextAttention {
self_attention: ChineseClipTextSelfAttention,
self_output: ChineseClipTextSelfOutput,
span: tracing::Span,
}
impl ChineseClipTextAttention {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let self_attention = ChineseClipTextSelfAttention::new(var.pp("self"), config)?;
let self_output = ChineseClipTextSelfOutput::new(var.pp("output"), config)?;
Ok(Self {
self_attention,
self_output,
span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
Ok(attention_output)
}
}
type HiddenActLayer = Activation;
/// Copied from [`crate::models::bert::BertIntermediate`] to [`ChineseClipTextIntermediate`]
#[derive(Clone, Debug)]
struct ChineseClipTextIntermediate {
dense: nn::Linear,
intermediate_act: HiddenActLayer,
span: tracing::Span,
}
impl ChineseClipTextIntermediate {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let dense = nn::linear(
config.hidden_size,
config.intermediate_size,
var.pp("dense"),
)?;
Ok(Self {
dense,
intermediate_act: config.hidden_act,
span: tracing::span!(tracing::Level::TRACE, "inter"),
})
}
}
impl Module for ChineseClipTextIntermediate {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?;
let ys = self.intermediate_act.forward(&hidden_states)?;
Ok(ys)
}
}
/// Copied from [`crate::models::bert::BertOutput`] to [`ChineseClipTextOutput`]
#[derive(Clone, Debug)]
struct ChineseClipTextOutput {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: nn::Dropout,
span: tracing::Span,
}
impl ChineseClipTextOutput {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let dense = nn::linear(
config.intermediate_size,
config.hidden_size,
var.pp("dense"),
)?;
let layer_norm = nn::layer_norm(
config.hidden_size,
config.layer_norm_eps,
var.pp("LayerNorm"),
)?;
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
layer_norm,
dropout,
span: tracing::span!(tracing::Level::TRACE, "out"),
})
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states, false)?;
self.layer_norm.forward(&(hidden_states + input_tensor)?)
}
}
/// Copied from [`crate::models::bert::BertLayer`] to [`ChineseClipTextLayer`]
#[derive(Clone, Debug)]
struct ChineseClipTextLayer {
attention: ChineseClipTextAttention,
intermediate: ChineseClipTextIntermediate,
output: ChineseClipTextOutput,
span: tracing::Span,
}
impl ChineseClipTextLayer {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let attention = ChineseClipTextAttention::new(var.pp("attention"), config)?;
let intermediate = ChineseClipTextIntermediate::new(var.pp("intermediate"), config)?;
let output = ChineseClipTextOutput::new(var.pp("output"), config)?;
Ok(Self {
attention,
intermediate,
output,
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let attention_output = self.attention.forward(hidden_states, attention_mask)?;
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
let intermediate_output = self.intermediate.forward(&attention_output)?;
let layer_output = self
.output
.forward(&intermediate_output, &attention_output)?;
Ok(layer_output)
}
}
#[derive(Clone, Debug)]
struct Tanh;
impl Tanh {
pub fn new() -> Self {
Self {}
}
}
impl Module for Tanh {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.tanh()
}
}
#[derive(Clone, Debug)]
struct ChineseClipTextPooler {
dense: nn::Linear,
activation: Tanh,
}
impl ChineseClipTextPooler {
pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
let activation = Tanh::new();
Ok(Self { dense, activation })
}
}
impl Module for ChineseClipTextPooler {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let first_token_tensor = hidden_states.i((.., 0))?;
let pooled_output = self.dense.forward(&first_token_tensor)?;
let pooled_output = self.activation.forward(&pooled_output)?;
Ok(pooled_output)
}
}
#[derive(Clone, Debug)]
struct ChineseClipTextEncoder {
layers: Vec<ChineseClipTextLayer>,
span: tracing::Span,
}
impl ChineseClipTextEncoder {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|index| ChineseClipTextLayer::new(var.pp(format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(ChineseClipTextEncoder { layers, span })
}
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();
// Use a loop rather than a fold as it's easier to modify when adding debug/...
for layer in self.layers.iter() {
hidden_states = layer.forward(&hidden_states, attention_mask)?
}
Ok(hidden_states)
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipTextTransformer {
embeddings: ChineseClipTextEmbeddings,
encoder: ChineseClipTextEncoder,
pooler: Option<ChineseClipTextPooler>,
pub device: Device,
span: tracing::Span,
}
impl ChineseClipTextTransformer {
pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let embeddings = ChineseClipTextEmbeddings::new(var.pp("embeddings"), config)?;
let encoder = ChineseClipTextEncoder::new(var.pp("encoder"), config)?;
// see: https://github.com/huggingface/transformers/blob/e40bb4845e0eefb52ec1e9cac9c2446ab36aef81/src/transformers/models/chinese_clip/modeling_chinese_clip.py#L1362
// In the original Python version of the code, the pooler is not used, and there are no parameters for the pooler in the weight file.
let pooler = if var.contains_tensor("pooler") {
Some(ChineseClipTextPooler::new(var.pp("pooler"), config)?)
} else {
None
};
Ok(Self {
embeddings,
encoder,
pooler,
device: var.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
pub fn forward(
&self,
input_ids: &Tensor,
token_type_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
let attention_mask = match attention_mask {
Some(attention_mask) => attention_mask.clone(),
None => input_ids.ones_like()?,
};
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
let encoder_output = encoder_outputs.i((.., 0, ..))?;
let pooled_output = match &self.pooler {
Some(pooler) => pooler.forward(&encoder_output)?,
None => encoder_output,
};
Ok(pooled_output)
}
}
fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
let attention_mask = match attention_mask.rank() {
3 => attention_mask.unsqueeze(1)?,
2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
_ => candle::bail!("Wrong shape for input_ids or attention_mask"),
};
let attention_mask = attention_mask.to_dtype(dtype)?;
// torch.finfo(dtype).min
(attention_mask.ones_like()? - &attention_mask)?
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
}

View File

@ -0,0 +1,385 @@
//! Chinese contrastive Language-Image Pre-Training
//!
//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/OFA-Sys/Chinese-CLIP
//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle_nn as nn;
use super::{Activation, EncoderConfig};
#[derive(Clone, Debug)]
pub struct ChineseClipVisionConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
pub projection_dim: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_channels: usize,
pub image_size: usize,
pub patch_size: usize,
pub hidden_act: Activation,
pub layer_norm_eps: f64,
pub attention_dropout: f32,
pub initializer_range: f32,
pub initializer_factor: f32,
}
impl Default for ChineseClipVisionConfig {
fn default() -> Self {
ChineseClipVisionConfig {
hidden_size: 768,
intermediate_size: 3072,
projection_dim: 512,
num_hidden_layers: 12,
num_attention_heads: 12,
num_channels: 3,
image_size: 224,
patch_size: 32,
hidden_act: Activation::QuickGelu,
layer_norm_eps: 1e-5,
attention_dropout: 0.0,
initializer_range: 0.02,
initializer_factor: 1.0,
}
}
}
impl ChineseClipVisionConfig {
/// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
pub fn clip_vit_base_patch16() -> Self {
Self {
hidden_size: 768,
intermediate_size: 3072,
projection_dim: 512,
num_hidden_layers: 12,
num_attention_heads: 12,
num_channels: 3,
image_size: 224,
patch_size: 16,
hidden_act: Activation::QuickGelu,
layer_norm_eps: 1e-5,
attention_dropout: 0.0,
initializer_range: 0.02,
initializer_factor: 1.0,
}
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipVisionEmbeddings {
patch_embedding: nn::Conv2d,
position_ids: Tensor,
class_embedding: Tensor,
position_embedding: nn::Embedding,
}
impl ChineseClipVisionEmbeddings {
pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
let embed_dim = config.hidden_size;
// originally nn.Parameter
let class_embedding = if var.contains_tensor("class_embedding") {
var.get(embed_dim, "class_embedding")?
} else {
Tensor::randn(0f32, 1f32, embed_dim, var.device())?
};
let num_patches = (config.image_size / config.patch_size).pow(2);
let num_positions = num_patches + 1;
let position_ids = Tensor::arange(0, num_positions as i64, var.device())?;
let conv2dconfig = nn::Conv2dConfig {
stride: config.patch_size,
..Default::default()
};
let position_embedding =
nn::embedding(num_positions, embed_dim, var.pp("position_embedding"))?;
let patch_embedding = nn::conv2d_no_bias(
config.num_channels,
embed_dim,
config.patch_size,
conv2dconfig,
var.pp("patch_embedding"),
)?;
Ok(Self {
patch_embedding,
position_ids,
class_embedding,
position_embedding,
})
}
}
impl Module for ChineseClipVisionEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let batch_size = xs.shape().dims();
let patch_embeds = self
.patch_embedding
.forward(xs)?
.flatten_from(2)?
.transpose(1, 2)?;
let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
let class_embeds = self.class_embedding.expand(shape)?;
let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
embeddings.broadcast_add(&position_embedding)
}
}
#[derive(Clone, Debug)]
struct ChineseClipVisionAttention {
k_proj: nn::Linear,
v_proj: nn::Linear,
q_proj: nn::Linear,
out_proj: nn::Linear,
head_dim: usize,
scale: f64,
num_attention_heads: usize,
}
impl ChineseClipVisionAttention {
fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
let embed_dim = config.embed_dim();
let num_attention_heads = config.num_attention_heads();
let k_proj = nn::linear(embed_dim, embed_dim, var.pp("k_proj"))?;
let v_proj = nn::linear(embed_dim, embed_dim, var.pp("v_proj"))?;
let q_proj = nn::linear(embed_dim, embed_dim, var.pp("q_proj"))?;
let out_proj = nn::linear(embed_dim, embed_dim, var.pp("out_proj"))?;
let head_dim = embed_dim / num_attention_heads;
let scale = (head_dim as f64).powf(-0.5);
Ok(ChineseClipVisionAttention {
k_proj,
v_proj,
q_proj,
out_proj,
head_dim,
scale,
num_attention_heads,
})
}
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let in_dtype = xs.dtype();
let (bsz, seq_len, embed_dim) = xs.dims3()?;
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
let query_states = self
.shape(&(self.q_proj.forward(xs)? * self.scale)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let key_states = self
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let value_states = self
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
attn_weights
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
.broadcast_add(causal_attention_mask)?
.reshape((bsz * self.num_attention_heads, seq_len, src_len))?
} else {
attn_weights
};
let attn_weights = nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
let attn_output = attn_output
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
.transpose(1, 2)?
.reshape((bsz, seq_len, embed_dim))?;
self.out_proj.forward(&attn_output)
}
}
#[derive(Clone, Debug)]
struct ChineseClipVisionMlp {
fc1: nn::Linear,
fc2: nn::Linear,
activation: Activation,
}
impl ChineseClipVisionMlp {
fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
let fc1 = nn::linear(
config.embed_dim(),
config.intermediate_size(),
var.pp("fc1"),
)?;
let fc2 = nn::linear(
config.intermediate_size(),
config.embed_dim(),
var.pp("fc2"),
)?;
Ok(ChineseClipVisionMlp {
fc1,
fc2,
activation: config.activation(),
})
}
}
impl ChineseClipVisionMlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?;
self.fc2.forward(&self.activation.forward(&xs)?)
}
}
#[derive(Clone, Debug)]
struct ChineseClipVisionEncoderLayer {
self_attn: ChineseClipVisionAttention,
layer_norm1: nn::LayerNorm,
mlp: ChineseClipVisionMlp,
layer_norm2: nn::LayerNorm,
}
impl ChineseClipVisionEncoderLayer {
fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
let self_attn = ChineseClipVisionAttention::new(var.pp("self_attn"), config)?;
let layer_norm1 = nn::layer_norm(
config.embed_dim(),
config.layer_norm_eps(),
var.pp("layer_norm1"),
)?;
let mlp = ChineseClipVisionMlp::new(var.pp("mlp"), config)?;
let layer_norm2 = nn::layer_norm(
config.embed_dim(),
config.layer_norm_eps(),
var.pp("layer_norm2"),
)?;
Ok(ChineseClipVisionEncoderLayer {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let residual = xs;
let xs = self.layer_norm1.forward(xs)?;
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self.layer_norm2.forward(&xs)?;
let xs = self.mlp.forward(&xs)?;
xs + residual
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipVisionEncoder {
layers: Vec<ChineseClipVisionEncoderLayer>,
}
impl ChineseClipVisionEncoder {
pub fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
let vs = var.pp("layers");
let mut layers: Vec<ChineseClipVisionEncoderLayer> = Vec::new();
for index in 0..config.num_hidden_layers() {
let layer = ChineseClipVisionEncoderLayer::new(vs.pp(index.to_string()), config)?;
layers.push(layer)
}
Ok(ChineseClipVisionEncoder { layers })
}
pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
}
Ok(xs)
}
// required by LLaVA
pub fn output_hidden_states(
&self,
xs: &Tensor,
causal_attention_mask: Option<&Tensor>,
) -> Result<Vec<Tensor>> {
let mut xs = xs.clone();
let mut hidden_states = Vec::new();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
hidden_states.push(xs.clone());
}
Ok(hidden_states)
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipVisionTransformer {
embeddings: ChineseClipVisionEmbeddings,
encoder: ChineseClipVisionEncoder,
pre_layer_norm: nn::LayerNorm,
final_layer_norm: nn::LayerNorm,
}
impl ChineseClipVisionTransformer {
pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
let embed_dim = config.hidden_size;
let embeddings = ChineseClipVisionEmbeddings::new(var.pp("embeddings"), config)?;
let pre_layer_norm =
nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("pre_layrnorm"))?;
let encoder = ChineseClipVisionEncoder::new(
var.pp("encoder"),
&EncoderConfig::Vision(config.clone()),
)?;
let final_layer_norm =
nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("post_layernorm"))?;
Ok(Self {
embeddings,
encoder,
final_layer_norm,
pre_layer_norm,
})
}
// required by LLaVA
pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
let hidden_states = pixel_values
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)
}
}
impl Module for ChineseClipVisionTransformer {
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
let hidden_states = pixel_values
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
// referer: https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
let pooled_output = encoder_outputs.i((.., 0, ..))?;
self.final_layer_norm.forward(&pooled_output)
}
}

View File

@ -5,6 +5,7 @@ pub mod bigcode;
pub mod blip;
pub mod blip_text;
pub mod chatglm;
pub mod chinese_clip;
pub mod clip;
pub mod codegeex4_9b;
pub mod colpali;