mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
feat: intergrate chinese clip and add example (#2555)
* start to impl chinese clip * impl vision model * copy code from bert * refactor use * refactor use again * fix text model * refactor * try to fix text model * tuning * tuning chinese clip * delete useless code * revert code * Clippy fixes. * Also apply cargo fmt. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
208
candle-transformers/src/models/chinese_clip/mod.rs
Normal file
208
candle-transformers/src/models/chinese_clip/mod.rs
Normal file
@ -0,0 +1,208 @@
|
||||
//! Chinese contrastive Language-Image Pre-Training
|
||||
//!
|
||||
//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||
//! pairs of images with related texts.
|
||||
//!
|
||||
//! https://github.com/OFA-Sys/Chinese-CLIP
|
||||
//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
|
||||
|
||||
use candle::{Module, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
use text_model::ChineseClipTextTransformer;
|
||||
use vision_model::ChineseClipVisionTransformer;
|
||||
|
||||
pub mod text_model;
|
||||
pub mod vision_model;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Activation {
|
||||
QuickGelu,
|
||||
Gelu,
|
||||
GeluNew,
|
||||
Relu,
|
||||
}
|
||||
|
||||
impl From<String> for Activation {
|
||||
fn from(value: String) -> Self {
|
||||
match value.as_str() {
|
||||
"quick_gelu" => Activation::QuickGelu,
|
||||
"gelu" => Activation::Gelu,
|
||||
"gelu_new" => Activation::GeluNew,
|
||||
"relu" => Activation::Relu,
|
||||
_ => panic!("Invalid activation function: {}", value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Activation {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
|
||||
Activation::Gelu => xs.gelu_erf(),
|
||||
Activation::GeluNew => xs.gelu(),
|
||||
Activation::Relu => xs.relu(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChineseClipConfig {
|
||||
pub text_config: text_model::ChineseClipTextConfig,
|
||||
pub vision_config: vision_model::ChineseClipVisionConfig,
|
||||
pub projection_dim: usize,
|
||||
pub logit_scale_init_value: f32,
|
||||
pub image_size: usize,
|
||||
}
|
||||
|
||||
impl ChineseClipConfig {
|
||||
/// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
|
||||
pub fn clip_vit_base_patch16() -> Self {
|
||||
let text_config = text_model::ChineseClipTextConfig::clip_vit_base_patch16();
|
||||
let vision_config = vision_model::ChineseClipVisionConfig::clip_vit_base_patch16();
|
||||
|
||||
Self {
|
||||
text_config,
|
||||
vision_config,
|
||||
projection_dim: 512,
|
||||
logit_scale_init_value: 2.6592,
|
||||
image_size: 512,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum EncoderConfig {
|
||||
Text(text_model::ChineseClipTextConfig),
|
||||
Vision(vision_model::ChineseClipVisionConfig),
|
||||
}
|
||||
|
||||
impl EncoderConfig {
|
||||
pub fn embed_dim(&self) -> usize {
|
||||
match self {
|
||||
Self::Text(c) => c.hidden_size,
|
||||
Self::Vision(c) => c.hidden_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_attention_heads(&self) -> usize {
|
||||
match self {
|
||||
Self::Text(c) => c.num_attention_heads,
|
||||
Self::Vision(c) => c.num_attention_heads,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn intermediate_size(&self) -> usize {
|
||||
match self {
|
||||
Self::Text(c) => c.intermediate_size,
|
||||
Self::Vision(c) => c.intermediate_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_hidden_layers(&self) -> usize {
|
||||
match self {
|
||||
Self::Text(c) => c.num_hidden_layers,
|
||||
Self::Vision(c) => c.num_hidden_layers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn activation(&self) -> Activation {
|
||||
match self {
|
||||
Self::Text(c) => c.hidden_act,
|
||||
Self::Vision(c) => c.hidden_act,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn layer_norm_eps(&self) -> f64 {
|
||||
match self {
|
||||
Self::Text(c) => c.layer_norm_eps,
|
||||
Self::Vision(c) => c.layer_norm_eps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChineseClipModel {
|
||||
text_model: ChineseClipTextTransformer,
|
||||
vision_model: ChineseClipVisionTransformer,
|
||||
visual_projection: nn::Linear,
|
||||
text_projection: nn::Linear,
|
||||
logit_scale: Tensor,
|
||||
}
|
||||
|
||||
impl ChineseClipModel {
|
||||
pub fn new(vs: nn::VarBuilder, c: &ChineseClipConfig) -> Result<Self> {
|
||||
let text_model = ChineseClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
|
||||
|
||||
let vision_model =
|
||||
ChineseClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
|
||||
|
||||
let vision_embed_dim = c.vision_config.hidden_size;
|
||||
let vision_projection = nn::linear_no_bias(
|
||||
vision_embed_dim,
|
||||
c.projection_dim,
|
||||
vs.pp("visual_projection"),
|
||||
)?;
|
||||
|
||||
let text_embed_dim = c.text_config.hidden_size;
|
||||
let text_projection =
|
||||
nn::linear_no_bias(text_embed_dim, c.projection_dim, vs.pp("text_projection"))?;
|
||||
|
||||
let logit_scale = if vs.contains_tensor("logit_scale") {
|
||||
vs.get(&[], "logit_scale")?
|
||||
} else {
|
||||
Tensor::new(&[c.logit_scale_init_value], vs.device())?
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
text_model,
|
||||
vision_model,
|
||||
visual_projection: vision_projection,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_text_features(
|
||||
&self,
|
||||
input_ids: &Tensor,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let output = self
|
||||
.text_model
|
||||
.forward(input_ids, token_type_ids, attention_mask)?;
|
||||
self.text_projection.forward(&output)
|
||||
}
|
||||
|
||||
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||
pixel_values
|
||||
.apply(&self.vision_model)?
|
||||
.apply(&self.visual_projection)
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
pixel_values: &Tensor,
|
||||
input_ids: &Tensor,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let image_features = self.get_image_features(pixel_values)?;
|
||||
let text_features = self.get_text_features(input_ids, token_type_ids, attention_mask)?;
|
||||
|
||||
let image_features_normalized = div_l2_norm(&image_features)?;
|
||||
let text_features_normalized = div_l2_norm(&text_features)?;
|
||||
|
||||
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
|
||||
let logit_scale = self.logit_scale.exp()?;
|
||||
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
|
||||
let logits_per_image = logits_per_text.t()?;
|
||||
Ok((logits_per_text, logits_per_image))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
|
||||
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
|
||||
v.broadcast_div(&l2_norm)
|
||||
}
|
540
candle-transformers/src/models/chinese_clip/text_model.rs
Normal file
540
candle-transformers/src/models/chinese_clip/text_model.rs
Normal file
@ -0,0 +1,540 @@
|
||||
//! Chinese contrastive Language-Image Pre-Training
|
||||
//!
|
||||
//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||
//! pairs of images with related texts.
|
||||
//!
|
||||
//! https://github.com/OFA-Sys/Chinese-CLIP
|
||||
//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
|
||||
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
use super::Activation;
|
||||
|
||||
/// Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
||||
/// positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
||||
/// [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
||||
/// For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
|
||||
/// with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum PositionEmbeddingType {
|
||||
Absolute,
|
||||
RelativeKey,
|
||||
RelativeKeyQuery,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChineseClipTextConfig {
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub hidden_act: Activation,
|
||||
pub hidden_dropout_prob: f32,
|
||||
pub attention_probs_dropout_prob: f64,
|
||||
pub max_position_embeddings: usize,
|
||||
pub type_vocab_size: usize,
|
||||
pub initializer_range: f64,
|
||||
pub initializer_factor: f64,
|
||||
pub layer_norm_eps: f64,
|
||||
pub pad_token_id: usize,
|
||||
pub position_embedding_type: PositionEmbeddingType,
|
||||
pub use_cache: bool,
|
||||
}
|
||||
|
||||
impl Default for ChineseClipTextConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vocab_size: 30522,
|
||||
hidden_size: 768,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
intermediate_size: 3072,
|
||||
hidden_act: Activation::Gelu,
|
||||
hidden_dropout_prob: 0.1,
|
||||
attention_probs_dropout_prob: 0.1,
|
||||
max_position_embeddings: 512,
|
||||
type_vocab_size: 2,
|
||||
initializer_range: 0.02,
|
||||
initializer_factor: 1.0,
|
||||
layer_norm_eps: 1e-12,
|
||||
pad_token_id: 0,
|
||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||
use_cache: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ChineseClipTextConfig {
|
||||
/// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
|
||||
pub fn clip_vit_base_patch16() -> Self {
|
||||
Self {
|
||||
vocab_size: 21128,
|
||||
hidden_size: 768,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
intermediate_size: 3072,
|
||||
hidden_act: Activation::Gelu,
|
||||
hidden_dropout_prob: 0.1,
|
||||
attention_probs_dropout_prob: 0.1,
|
||||
max_position_embeddings: 512,
|
||||
type_vocab_size: 2,
|
||||
initializer_range: 0.02,
|
||||
initializer_factor: 1.0,
|
||||
layer_norm_eps: 1e-12,
|
||||
pad_token_id: 0,
|
||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||
use_cache: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChineseClipTextEmbeddings {
|
||||
word_embeddings: nn::Embedding,
|
||||
position_embeddings: nn::Embedding,
|
||||
token_type_embeddings: nn::Embedding,
|
||||
layer_norm: nn::LayerNorm,
|
||||
dropout: nn::Dropout,
|
||||
position_embedding_type: PositionEmbeddingType,
|
||||
position_ids: Tensor,
|
||||
token_type_ids: Tensor,
|
||||
}
|
||||
|
||||
impl ChineseClipTextEmbeddings {
|
||||
pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let word_embeddings = nn::embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
var.pp("word_embeddings"),
|
||||
)?;
|
||||
let position_embeddings = nn::embedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
var.pp("position_embeddings"),
|
||||
)?;
|
||||
let token_type_embeddings = nn::embedding(
|
||||
config.type_vocab_size,
|
||||
config.hidden_size,
|
||||
var.pp("token_type_embeddings"),
|
||||
)?;
|
||||
let layer_norm = nn::layer_norm::<f64>(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
var.pp("LayerNorm"),
|
||||
)?;
|
||||
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
|
||||
let position_ids =
|
||||
Tensor::arange(0u32, config.max_position_embeddings as u32, var.device())?
|
||||
.unsqueeze(0)?;
|
||||
let token_type_ids = Tensor::zeros(position_ids.shape(), DType::I64, var.device())?;
|
||||
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
position_embeddings,
|
||||
token_type_embeddings,
|
||||
layer_norm,
|
||||
dropout,
|
||||
position_embedding_type: config.position_embedding_type.clone(),
|
||||
position_ids,
|
||||
token_type_ids,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor> {
|
||||
let (_batch_size, seq_length) = xs.dims2()?;
|
||||
let position_ids = (0..seq_length as u32).collect::<Vec<_>>();
|
||||
let position_ids = self.position_ids.index_select(
|
||||
&Tensor::new(&position_ids[..], self.position_ids.device())?,
|
||||
1,
|
||||
)?;
|
||||
|
||||
let word_embeddings = self.word_embeddings.forward(xs)?;
|
||||
|
||||
let token_type_ids = match token_type_ids {
|
||||
Some(token_type_ids) => token_type_ids,
|
||||
None => &self.token_type_ids.i((.., 0..seq_length))?,
|
||||
};
|
||||
let token_type_ids = token_type_ids.expand(xs.shape())?;
|
||||
let token_type_embeddings = self.token_type_embeddings.forward(&token_type_ids)?;
|
||||
|
||||
let embeddings = (&word_embeddings + token_type_embeddings)?;
|
||||
let embeddings = match self.position_embedding_type {
|
||||
PositionEmbeddingType::Absolute => {
|
||||
let position_embeddings = self.position_embeddings.forward(&position_ids)?;
|
||||
let position_embeddings = position_embeddings.expand(embeddings.shape())?;
|
||||
(embeddings + position_embeddings)?
|
||||
}
|
||||
_ => embeddings,
|
||||
};
|
||||
let embeddings = self.layer_norm.forward(&embeddings)?;
|
||||
let embeddings = self.dropout.forward(&embeddings, false)?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
/// Copied from [`crate::models::bert::BertSelfOutput`] to [`ChineseClipTextSelfOutput`]
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipTextSelfOutput {
|
||||
dense: nn::Linear,
|
||||
layer_norm: nn::LayerNorm,
|
||||
dropout: nn::Dropout,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ChineseClipTextSelfOutput {
|
||||
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
|
||||
let layer_norm = nn::layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
var.pp("LayerNorm"),
|
||||
)?;
|
||||
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
|
||||
Ok(Self {
|
||||
dense,
|
||||
layer_norm,
|
||||
dropout,
|
||||
span: tracing::span!(tracing::Level::TRACE, "self-out"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let hidden_states = self.dropout.forward(&hidden_states, false)?;
|
||||
self.layer_norm.forward(&(hidden_states + input_tensor)?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Copied from [`crate::models::bert::BertSelfAttention`] to [`ChineseClipTextSelfAttention`]
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipTextSelfAttention {
|
||||
query: nn::Linear,
|
||||
key: nn::Linear,
|
||||
value: nn::Linear,
|
||||
dropout: nn::Dropout,
|
||||
num_attention_heads: usize,
|
||||
attention_head_size: usize,
|
||||
span: tracing::Span,
|
||||
span_softmax: tracing::Span,
|
||||
}
|
||||
|
||||
impl ChineseClipTextSelfAttention {
|
||||
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||
let all_head_size = config.num_attention_heads * attention_head_size;
|
||||
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
|
||||
let hidden_size = config.hidden_size;
|
||||
let query = nn::linear(hidden_size, all_head_size, var.pp("query"))?;
|
||||
let value = nn::linear(hidden_size, all_head_size, var.pp("value"))?;
|
||||
let key = nn::linear(hidden_size, all_head_size, var.pp("key"))?;
|
||||
Ok(Self {
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout,
|
||||
num_attention_heads: config.num_attention_heads,
|
||||
attention_head_size,
|
||||
span: tracing::span!(tracing::Level::TRACE, "self-attn"),
|
||||
span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
|
||||
})
|
||||
}
|
||||
|
||||
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut new_x_shape = xs.dims().to_vec();
|
||||
new_x_shape.pop();
|
||||
new_x_shape.push(self.num_attention_heads);
|
||||
new_x_shape.push(self.attention_head_size);
|
||||
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
|
||||
xs.contiguous()
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let query_layer = self.query.forward(hidden_states)?;
|
||||
let key_layer = self.key.forward(hidden_states)?;
|
||||
let value_layer = self.value.forward(hidden_states)?;
|
||||
|
||||
let query_layer = self.transpose_for_scores(&query_layer)?;
|
||||
let key_layer = self.transpose_for_scores(&key_layer)?;
|
||||
let value_layer = self.transpose_for_scores(&value_layer)?;
|
||||
|
||||
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
|
||||
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
|
||||
let attention_scores = attention_scores.broadcast_add(attention_mask)?;
|
||||
let attention_probs = {
|
||||
let _enter_sm = self.span_softmax.enter();
|
||||
nn::ops::softmax(&attention_scores, candle::D::Minus1)?
|
||||
};
|
||||
let attention_probs = self.dropout.forward(&attention_probs, false)?;
|
||||
|
||||
let context_layer = attention_probs.matmul(&value_layer)?;
|
||||
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
|
||||
let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
|
||||
Ok(context_layer)
|
||||
}
|
||||
}
|
||||
|
||||
/// Copied from [`crate::models::bert::BertAttention`] to [`ChineseClipTextAttention`]
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipTextAttention {
|
||||
self_attention: ChineseClipTextSelfAttention,
|
||||
self_output: ChineseClipTextSelfOutput,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ChineseClipTextAttention {
|
||||
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let self_attention = ChineseClipTextSelfAttention::new(var.pp("self"), config)?;
|
||||
let self_output = ChineseClipTextSelfOutput::new(var.pp("output"), config)?;
|
||||
Ok(Self {
|
||||
self_attention,
|
||||
self_output,
|
||||
span: tracing::span!(tracing::Level::TRACE, "attn"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
|
||||
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
|
||||
Ok(attention_output)
|
||||
}
|
||||
}
|
||||
|
||||
type HiddenActLayer = Activation;
|
||||
|
||||
/// Copied from [`crate::models::bert::BertIntermediate`] to [`ChineseClipTextIntermediate`]
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipTextIntermediate {
|
||||
dense: nn::Linear,
|
||||
intermediate_act: HiddenActLayer,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ChineseClipTextIntermediate {
|
||||
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let dense = nn::linear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
var.pp("dense"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
dense,
|
||||
intermediate_act: config.hidden_act,
|
||||
span: tracing::span!(tracing::Level::TRACE, "inter"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ChineseClipTextIntermediate {
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let ys = self.intermediate_act.forward(&hidden_states)?;
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
/// Copied from [`crate::models::bert::BertOutput`] to [`ChineseClipTextOutput`]
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipTextOutput {
|
||||
dense: nn::Linear,
|
||||
layer_norm: nn::LayerNorm,
|
||||
dropout: nn::Dropout,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ChineseClipTextOutput {
|
||||
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let dense = nn::linear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
var.pp("dense"),
|
||||
)?;
|
||||
let layer_norm = nn::layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
var.pp("LayerNorm"),
|
||||
)?;
|
||||
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
|
||||
Ok(Self {
|
||||
dense,
|
||||
layer_norm,
|
||||
dropout,
|
||||
span: tracing::span!(tracing::Level::TRACE, "out"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let hidden_states = self.dropout.forward(&hidden_states, false)?;
|
||||
self.layer_norm.forward(&(hidden_states + input_tensor)?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Copied from [`crate::models::bert::BertLayer`] to [`ChineseClipTextLayer`]
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipTextLayer {
|
||||
attention: ChineseClipTextAttention,
|
||||
intermediate: ChineseClipTextIntermediate,
|
||||
output: ChineseClipTextOutput,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ChineseClipTextLayer {
|
||||
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let attention = ChineseClipTextAttention::new(var.pp("attention"), config)?;
|
||||
let intermediate = ChineseClipTextIntermediate::new(var.pp("intermediate"), config)?;
|
||||
let output = ChineseClipTextOutput::new(var.pp("output"), config)?;
|
||||
Ok(Self {
|
||||
attention,
|
||||
intermediate,
|
||||
output,
|
||||
span: tracing::span!(tracing::Level::TRACE, "layer"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let attention_output = self.attention.forward(hidden_states, attention_mask)?;
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
|
||||
let intermediate_output = self.intermediate.forward(&attention_output)?;
|
||||
let layer_output = self
|
||||
.output
|
||||
.forward(&intermediate_output, &attention_output)?;
|
||||
Ok(layer_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Tanh;
|
||||
|
||||
impl Tanh {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
impl Module for Tanh {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.tanh()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipTextPooler {
|
||||
dense: nn::Linear,
|
||||
activation: Tanh,
|
||||
}
|
||||
|
||||
impl ChineseClipTextPooler {
|
||||
pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
|
||||
let activation = Tanh::new();
|
||||
Ok(Self { dense, activation })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ChineseClipTextPooler {
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let first_token_tensor = hidden_states.i((.., 0))?;
|
||||
let pooled_output = self.dense.forward(&first_token_tensor)?;
|
||||
let pooled_output = self.activation.forward(&pooled_output)?;
|
||||
Ok(pooled_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipTextEncoder {
|
||||
layers: Vec<ChineseClipTextLayer>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ChineseClipTextEncoder {
|
||||
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let layers = (0..config.num_hidden_layers)
|
||||
.map(|index| ChineseClipTextLayer::new(var.pp(format!("layer.{index}")), config))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "encoder");
|
||||
Ok(ChineseClipTextEncoder { layers, span })
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut hidden_states = hidden_states.clone();
|
||||
// Use a loop rather than a fold as it's easier to modify when adding debug/...
|
||||
for layer in self.layers.iter() {
|
||||
hidden_states = layer.forward(&hidden_states, attention_mask)?
|
||||
}
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChineseClipTextTransformer {
|
||||
embeddings: ChineseClipTextEmbeddings,
|
||||
encoder: ChineseClipTextEncoder,
|
||||
pooler: Option<ChineseClipTextPooler>,
|
||||
pub device: Device,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ChineseClipTextTransformer {
|
||||
pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
|
||||
let embeddings = ChineseClipTextEmbeddings::new(var.pp("embeddings"), config)?;
|
||||
let encoder = ChineseClipTextEncoder::new(var.pp("encoder"), config)?;
|
||||
// see: https://github.com/huggingface/transformers/blob/e40bb4845e0eefb52ec1e9cac9c2446ab36aef81/src/transformers/models/chinese_clip/modeling_chinese_clip.py#L1362
|
||||
// In the original Python version of the code, the pooler is not used, and there are no parameters for the pooler in the weight file.
|
||||
let pooler = if var.contains_tensor("pooler") {
|
||||
Some(ChineseClipTextPooler::new(var.pp("pooler"), config)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
pooler,
|
||||
device: var.device().clone(),
|
||||
span: tracing::span!(tracing::Level::TRACE, "model"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
input_ids: &Tensor,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
|
||||
let attention_mask = match attention_mask {
|
||||
Some(attention_mask) => attention_mask.clone(),
|
||||
None => input_ids.ones_like()?,
|
||||
};
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
||||
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
|
||||
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
|
||||
let encoder_output = encoder_outputs.i((.., 0, ..))?;
|
||||
let pooled_output = match &self.pooler {
|
||||
Some(pooler) => pooler.forward(&encoder_output)?,
|
||||
None => encoder_output,
|
||||
};
|
||||
|
||||
Ok(pooled_output)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
|
||||
let attention_mask = match attention_mask.rank() {
|
||||
3 => attention_mask.unsqueeze(1)?,
|
||||
2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
|
||||
_ => candle::bail!("Wrong shape for input_ids or attention_mask"),
|
||||
};
|
||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||
// torch.finfo(dtype).min
|
||||
(attention_mask.ones_like()? - &attention_mask)?
|
||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
||||
}
|
385
candle-transformers/src/models/chinese_clip/vision_model.rs
Normal file
385
candle-transformers/src/models/chinese_clip/vision_model.rs
Normal file
@ -0,0 +1,385 @@
|
||||
//! Chinese contrastive Language-Image Pre-Training
|
||||
//!
|
||||
//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||
//! pairs of images with related texts.
|
||||
//!
|
||||
//! https://github.com/OFA-Sys/Chinese-CLIP
|
||||
//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
|
||||
|
||||
use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
use super::{Activation, EncoderConfig};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChineseClipVisionConfig {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub projection_dim: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_channels: usize,
|
||||
pub image_size: usize,
|
||||
pub patch_size: usize,
|
||||
pub hidden_act: Activation,
|
||||
pub layer_norm_eps: f64,
|
||||
pub attention_dropout: f32,
|
||||
pub initializer_range: f32,
|
||||
pub initializer_factor: f32,
|
||||
}
|
||||
|
||||
impl Default for ChineseClipVisionConfig {
|
||||
fn default() -> Self {
|
||||
ChineseClipVisionConfig {
|
||||
hidden_size: 768,
|
||||
intermediate_size: 3072,
|
||||
projection_dim: 512,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
num_channels: 3,
|
||||
image_size: 224,
|
||||
patch_size: 32,
|
||||
hidden_act: Activation::QuickGelu,
|
||||
layer_norm_eps: 1e-5,
|
||||
attention_dropout: 0.0,
|
||||
initializer_range: 0.02,
|
||||
initializer_factor: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ChineseClipVisionConfig {
|
||||
/// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
|
||||
pub fn clip_vit_base_patch16() -> Self {
|
||||
Self {
|
||||
hidden_size: 768,
|
||||
intermediate_size: 3072,
|
||||
projection_dim: 512,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
num_channels: 3,
|
||||
image_size: 224,
|
||||
patch_size: 16,
|
||||
hidden_act: Activation::QuickGelu,
|
||||
layer_norm_eps: 1e-5,
|
||||
attention_dropout: 0.0,
|
||||
initializer_range: 0.02,
|
||||
initializer_factor: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChineseClipVisionEmbeddings {
|
||||
patch_embedding: nn::Conv2d,
|
||||
position_ids: Tensor,
|
||||
class_embedding: Tensor,
|
||||
position_embedding: nn::Embedding,
|
||||
}
|
||||
|
||||
impl ChineseClipVisionEmbeddings {
|
||||
pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
|
||||
let embed_dim = config.hidden_size;
|
||||
// originally nn.Parameter
|
||||
let class_embedding = if var.contains_tensor("class_embedding") {
|
||||
var.get(embed_dim, "class_embedding")?
|
||||
} else {
|
||||
Tensor::randn(0f32, 1f32, embed_dim, var.device())?
|
||||
};
|
||||
|
||||
let num_patches = (config.image_size / config.patch_size).pow(2);
|
||||
let num_positions = num_patches + 1;
|
||||
let position_ids = Tensor::arange(0, num_positions as i64, var.device())?;
|
||||
|
||||
let conv2dconfig = nn::Conv2dConfig {
|
||||
stride: config.patch_size,
|
||||
..Default::default()
|
||||
};
|
||||
let position_embedding =
|
||||
nn::embedding(num_positions, embed_dim, var.pp("position_embedding"))?;
|
||||
let patch_embedding = nn::conv2d_no_bias(
|
||||
config.num_channels,
|
||||
embed_dim,
|
||||
config.patch_size,
|
||||
conv2dconfig,
|
||||
var.pp("patch_embedding"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
patch_embedding,
|
||||
position_ids,
|
||||
class_embedding,
|
||||
position_embedding,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ChineseClipVisionEmbeddings {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let batch_size = xs.shape().dims();
|
||||
let patch_embeds = self
|
||||
.patch_embedding
|
||||
.forward(xs)?
|
||||
.flatten_from(2)?
|
||||
.transpose(1, 2)?;
|
||||
let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
|
||||
let class_embeds = self.class_embedding.expand(shape)?;
|
||||
let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
|
||||
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
|
||||
embeddings.broadcast_add(&position_embedding)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipVisionAttention {
|
||||
k_proj: nn::Linear,
|
||||
v_proj: nn::Linear,
|
||||
q_proj: nn::Linear,
|
||||
out_proj: nn::Linear,
|
||||
head_dim: usize,
|
||||
scale: f64,
|
||||
num_attention_heads: usize,
|
||||
}
|
||||
|
||||
impl ChineseClipVisionAttention {
|
||||
fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
|
||||
let embed_dim = config.embed_dim();
|
||||
let num_attention_heads = config.num_attention_heads();
|
||||
let k_proj = nn::linear(embed_dim, embed_dim, var.pp("k_proj"))?;
|
||||
let v_proj = nn::linear(embed_dim, embed_dim, var.pp("v_proj"))?;
|
||||
let q_proj = nn::linear(embed_dim, embed_dim, var.pp("q_proj"))?;
|
||||
let out_proj = nn::linear(embed_dim, embed_dim, var.pp("out_proj"))?;
|
||||
let head_dim = embed_dim / num_attention_heads;
|
||||
let scale = (head_dim as f64).powf(-0.5);
|
||||
|
||||
Ok(ChineseClipVisionAttention {
|
||||
k_proj,
|
||||
v_proj,
|
||||
q_proj,
|
||||
out_proj,
|
||||
head_dim,
|
||||
scale,
|
||||
num_attention_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
|
||||
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let in_dtype = xs.dtype();
|
||||
let (bsz, seq_len, embed_dim) = xs.dims3()?;
|
||||
|
||||
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
|
||||
let query_states = self
|
||||
.shape(&(self.q_proj.forward(xs)? * self.scale)?, seq_len, bsz)?
|
||||
.reshape(proj_shape)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let key_states = self
|
||||
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
|
||||
.reshape(proj_shape)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let value_states = self
|
||||
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
|
||||
.reshape(proj_shape)?
|
||||
.to_dtype(DType::F32)?;
|
||||
|
||||
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
||||
|
||||
let src_len = key_states.dim(1)?;
|
||||
|
||||
let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
|
||||
attn_weights
|
||||
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
|
||||
.broadcast_add(causal_attention_mask)?
|
||||
.reshape((bsz * self.num_attention_heads, seq_len, src_len))?
|
||||
} else {
|
||||
attn_weights
|
||||
};
|
||||
|
||||
let attn_weights = nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
|
||||
let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
|
||||
let attn_output = attn_output
|
||||
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((bsz, seq_len, embed_dim))?;
|
||||
self.out_proj.forward(&attn_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipVisionMlp {
|
||||
fc1: nn::Linear,
|
||||
fc2: nn::Linear,
|
||||
activation: Activation,
|
||||
}
|
||||
|
||||
impl ChineseClipVisionMlp {
|
||||
fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
|
||||
let fc1 = nn::linear(
|
||||
config.embed_dim(),
|
||||
config.intermediate_size(),
|
||||
var.pp("fc1"),
|
||||
)?;
|
||||
let fc2 = nn::linear(
|
||||
config.intermediate_size(),
|
||||
config.embed_dim(),
|
||||
var.pp("fc2"),
|
||||
)?;
|
||||
|
||||
Ok(ChineseClipVisionMlp {
|
||||
fc1,
|
||||
fc2,
|
||||
activation: config.activation(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ChineseClipVisionMlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?;
|
||||
self.fc2.forward(&self.activation.forward(&xs)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ChineseClipVisionEncoderLayer {
|
||||
self_attn: ChineseClipVisionAttention,
|
||||
layer_norm1: nn::LayerNorm,
|
||||
mlp: ChineseClipVisionMlp,
|
||||
layer_norm2: nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl ChineseClipVisionEncoderLayer {
|
||||
fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
|
||||
let self_attn = ChineseClipVisionAttention::new(var.pp("self_attn"), config)?;
|
||||
let layer_norm1 = nn::layer_norm(
|
||||
config.embed_dim(),
|
||||
config.layer_norm_eps(),
|
||||
var.pp("layer_norm1"),
|
||||
)?;
|
||||
let mlp = ChineseClipVisionMlp::new(var.pp("mlp"), config)?;
|
||||
let layer_norm2 = nn::layer_norm(
|
||||
config.embed_dim(),
|
||||
config.layer_norm_eps(),
|
||||
var.pp("layer_norm2"),
|
||||
)?;
|
||||
|
||||
Ok(ChineseClipVisionEncoderLayer {
|
||||
self_attn,
|
||||
layer_norm1,
|
||||
mlp,
|
||||
layer_norm2,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.layer_norm1.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
|
||||
let xs = (xs + residual)?;
|
||||
|
||||
let residual = &xs;
|
||||
let xs = self.layer_norm2.forward(&xs)?;
|
||||
let xs = self.mlp.forward(&xs)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChineseClipVisionEncoder {
|
||||
layers: Vec<ChineseClipVisionEncoderLayer>,
|
||||
}
|
||||
|
||||
impl ChineseClipVisionEncoder {
|
||||
pub fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
|
||||
let vs = var.pp("layers");
|
||||
let mut layers: Vec<ChineseClipVisionEncoderLayer> = Vec::new();
|
||||
for index in 0..config.num_hidden_layers() {
|
||||
let layer = ChineseClipVisionEncoderLayer::new(vs.pp(index.to_string()), config)?;
|
||||
layers.push(layer)
|
||||
}
|
||||
Ok(ChineseClipVisionEncoder { layers })
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, causal_attention_mask)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
// required by LLaVA
|
||||
pub fn output_hidden_states(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
causal_attention_mask: Option<&Tensor>,
|
||||
) -> Result<Vec<Tensor>> {
|
||||
let mut xs = xs.clone();
|
||||
let mut hidden_states = Vec::new();
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, causal_attention_mask)?;
|
||||
hidden_states.push(xs.clone());
|
||||
}
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChineseClipVisionTransformer {
|
||||
embeddings: ChineseClipVisionEmbeddings,
|
||||
encoder: ChineseClipVisionEncoder,
|
||||
pre_layer_norm: nn::LayerNorm,
|
||||
final_layer_norm: nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl ChineseClipVisionTransformer {
|
||||
pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
|
||||
let embed_dim = config.hidden_size;
|
||||
let embeddings = ChineseClipVisionEmbeddings::new(var.pp("embeddings"), config)?;
|
||||
let pre_layer_norm =
|
||||
nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("pre_layrnorm"))?;
|
||||
let encoder = ChineseClipVisionEncoder::new(
|
||||
var.pp("encoder"),
|
||||
&EncoderConfig::Vision(config.clone()),
|
||||
)?;
|
||||
let final_layer_norm =
|
||||
nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("post_layernorm"))?;
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
final_layer_norm,
|
||||
pre_layer_norm,
|
||||
})
|
||||
}
|
||||
// required by LLaVA
|
||||
pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
|
||||
let hidden_states = pixel_values
|
||||
.apply(&self.embeddings)?
|
||||
.apply(&self.pre_layer_norm)?;
|
||||
|
||||
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
|
||||
let encoder_outputs = result.last().unwrap();
|
||||
let pooled_output = encoder_outputs.i((.., 0, ..))?;
|
||||
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ChineseClipVisionTransformer {
|
||||
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = pixel_values
|
||||
.apply(&self.embeddings)?
|
||||
.apply(&self.pre_layer_norm)?;
|
||||
|
||||
let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
|
||||
|
||||
// referer: https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
|
||||
let pooled_output = encoder_outputs.i((.., 0, ..))?;
|
||||
self.final_layer_norm.forward(&pooled_output)
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ pub mod bigcode;
|
||||
pub mod blip;
|
||||
pub mod blip_text;
|
||||
pub mod chatglm;
|
||||
pub mod chinese_clip;
|
||||
pub mod clip;
|
||||
pub mod codegeex4_9b;
|
||||
pub mod colpali;
|
||||
|
Reference in New Issue
Block a user