mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Fix clippy lints + minor cleanups. (#1957)
* Fix clippy lints + minor cleanups. * fmt. * Derive clone.
This commit is contained in:
@ -10,13 +10,11 @@ use self::{
|
|||||||
vision_model::ClipVisionTransformer,
|
vision_model::ClipVisionTransformer,
|
||||||
};
|
};
|
||||||
use candle::{Result, Tensor, D};
|
use candle::{Result, Tensor, D};
|
||||||
use candle_nn::Module;
|
|
||||||
|
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
pub mod text_model;
|
pub mod text_model;
|
||||||
pub mod vision_model;
|
pub mod vision_model;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
pub struct ClipModel {
|
pub struct ClipModel {
|
||||||
text_model: ClipTextTransformer,
|
text_model: ClipTextTransformer,
|
||||||
vision_model: ClipVisionTransformer,
|
vision_model: ClipVisionTransformer,
|
||||||
@ -25,6 +23,7 @@ pub struct ClipModel {
|
|||||||
logit_scale: Tensor,
|
logit_scale: Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
pub enum EncoderConfig {
|
pub enum EncoderConfig {
|
||||||
Text(text_model::ClipTextConfig),
|
Text(text_model::ClipTextConfig),
|
||||||
Vision(vision_model::ClipVisionConfig),
|
Vision(vision_model::ClipVisionConfig),
|
||||||
@ -67,6 +66,7 @@ impl EncoderConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
pub struct ClipConfig {
|
pub struct ClipConfig {
|
||||||
pub text_config: text_model::ClipTextConfig,
|
pub text_config: text_model::ClipTextConfig,
|
||||||
pub vision_config: vision_model::ClipVisionConfig,
|
pub vision_config: vision_model::ClipVisionConfig,
|
||||||
@ -111,7 +111,6 @@ impl ClipModel {
|
|||||||
let logit_scale = if vs.contains_tensor("logit_scale") {
|
let logit_scale = if vs.contains_tensor("logit_scale") {
|
||||||
vs.get(&[], "logit_scale")?
|
vs.get(&[], "logit_scale")?
|
||||||
} else {
|
} else {
|
||||||
warn!("Creating logit_scale tensor, results may vary.");
|
|
||||||
Tensor::new(&[c.logit_scale_init_value], vs.device())?
|
Tensor::new(&[c.logit_scale_init_value], vs.device())?
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -125,38 +124,26 @@ impl ClipModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
|
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
let text_outputs = self.text_model.forward(input_ids)?;
|
input_ids
|
||||||
|
.apply(&self.text_model)?
|
||||||
let text_features = self.text_projection.forward(&text_outputs)?;
|
.apply(&self.text_projection)
|
||||||
|
|
||||||
Ok(text_features)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||||
let image_features = self.vision_model.forward(pixel_values)?;
|
pixel_values
|
||||||
|
.apply(&self.vision_model)?
|
||||||
let image_features = self.visual_projection.forward(&image_features)?;
|
.apply(&self.visual_projection)
|
||||||
|
|
||||||
Ok(image_features)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
|
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||||
let image_features = self.get_image_features(pixel_values)?;
|
let image_features = self.get_image_features(pixel_values)?;
|
||||||
|
|
||||||
let text_features = self.get_text_features(input_ids)?;
|
let text_features = self.get_text_features(input_ids)?;
|
||||||
|
|
||||||
let image_features_normalized = div_l2_norm(&image_features)?;
|
let image_features_normalized = div_l2_norm(&image_features)?;
|
||||||
|
|
||||||
let text_features_normalized = div_l2_norm(&text_features)?;
|
let text_features_normalized = div_l2_norm(&text_features)?;
|
||||||
|
|
||||||
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
|
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
|
||||||
|
let logit_scale = self.logit_scale.exp()?;
|
||||||
let logit_scale = &self.logit_scale.exp()?;
|
|
||||||
|
|
||||||
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
|
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
|
||||||
|
|
||||||
let logits_per_image = logits_per_text.t()?;
|
let logits_per_image = logits_per_text.t()?;
|
||||||
|
|
||||||
Ok((logits_per_text, logits_per_image))
|
Ok((logits_per_text, logits_per_image))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ impl ClipTextConfig {
|
|||||||
|
|
||||||
// ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model.
|
// ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model.
|
||||||
// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
|
// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct ClipTextEmbeddings {
|
struct ClipTextEmbeddings {
|
||||||
token_embedding: candle_nn::Embedding,
|
token_embedding: candle_nn::Embedding,
|
||||||
position_embedding: candle_nn::Embedding,
|
position_embedding: candle_nn::Embedding,
|
||||||
@ -70,16 +70,13 @@ impl ClipTextEmbeddings {
|
|||||||
fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
|
fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
|
||||||
let token_embedding =
|
let token_embedding =
|
||||||
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
|
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
|
||||||
|
|
||||||
let position_embedding: nn::Embedding = candle_nn::embedding(
|
let position_embedding: nn::Embedding = candle_nn::embedding(
|
||||||
c.max_position_embeddings,
|
c.max_position_embeddings,
|
||||||
c.embed_dim,
|
c.embed_dim,
|
||||||
vs.pp("position_embedding"),
|
vs.pp("position_embedding"),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let position_ids =
|
let position_ids =
|
||||||
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
|
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
|
||||||
|
|
||||||
Ok(ClipTextEmbeddings {
|
Ok(ClipTextEmbeddings {
|
||||||
token_embedding,
|
token_embedding,
|
||||||
position_embedding,
|
position_embedding,
|
||||||
@ -91,20 +88,14 @@ impl ClipTextEmbeddings {
|
|||||||
impl Module for ClipTextEmbeddings {
|
impl Module for ClipTextEmbeddings {
|
||||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
let seq_length = input_ids.dim(D::Minus1)?;
|
let seq_length = input_ids.dim(D::Minus1)?;
|
||||||
|
let inputs_embeds = self.token_embedding.forward(input_ids)?;
|
||||||
let inputs_embeds = &self.token_embedding.forward(input_ids)?;
|
let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
|
||||||
|
let position_embedding = self.position_embedding.forward(&position_ids)?;
|
||||||
let postion_ids = &self.position_ids.narrow(1, 0, seq_length)?;
|
inputs_embeds.broadcast_add(&position_embedding)
|
||||||
|
|
||||||
let position_embedding = &self.position_embedding.forward(&postion_ids)?;
|
|
||||||
|
|
||||||
let inputs_embeds = inputs_embeds.broadcast_add(&position_embedding)?;
|
|
||||||
|
|
||||||
Ok(inputs_embeds)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct ClipAttention {
|
struct ClipAttention {
|
||||||
k_proj: candle_nn::Linear,
|
k_proj: candle_nn::Linear,
|
||||||
v_proj: candle_nn::Linear,
|
v_proj: candle_nn::Linear,
|
||||||
@ -166,15 +157,10 @@ impl ClipAttention {
|
|||||||
let src_len = key_states.dim(1)?;
|
let src_len = key_states.dim(1)?;
|
||||||
|
|
||||||
let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
|
let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
|
||||||
let attn_reshape =
|
|
||||||
attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?;
|
|
||||||
|
|
||||||
let attn_weights = attn_reshape.broadcast_add(causal_attention_mask)?;
|
|
||||||
|
|
||||||
let attn_weights =
|
|
||||||
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
|
|
||||||
|
|
||||||
attn_weights
|
attn_weights
|
||||||
|
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
|
||||||
|
.broadcast_add(causal_attention_mask)?
|
||||||
|
.reshape((bsz * self.num_attention_heads, seq_len, src_len))?
|
||||||
} else {
|
} else {
|
||||||
attn_weights
|
attn_weights
|
||||||
};
|
};
|
||||||
@ -190,7 +176,7 @@ impl ClipAttention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct ClipMlp {
|
struct ClipMlp {
|
||||||
fc1: candle_nn::Linear,
|
fc1: candle_nn::Linear,
|
||||||
fc2: candle_nn::Linear,
|
fc2: candle_nn::Linear,
|
||||||
@ -217,7 +203,7 @@ impl ClipMlp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct ClipEncoderLayer {
|
struct ClipEncoderLayer {
|
||||||
self_attn: ClipAttention,
|
self_attn: ClipAttention,
|
||||||
layer_norm1: candle_nn::LayerNorm,
|
layer_norm1: candle_nn::LayerNorm,
|
||||||
@ -253,7 +239,7 @@ impl ClipEncoderLayer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct ClipEncoder {
|
pub struct ClipEncoder {
|
||||||
layers: Vec<ClipEncoderLayer>,
|
layers: Vec<ClipEncoderLayer>,
|
||||||
}
|
}
|
||||||
@ -271,7 +257,6 @@ impl ClipEncoder {
|
|||||||
|
|
||||||
pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
let mut xs = xs.clone();
|
let mut xs = xs.clone();
|
||||||
|
|
||||||
for layer in self.layers.iter() {
|
for layer in self.layers.iter() {
|
||||||
xs = layer.forward(&xs, causal_attention_mask)?;
|
xs = layer.forward(&xs, causal_attention_mask)?;
|
||||||
}
|
}
|
||||||
@ -280,7 +265,7 @@ impl ClipEncoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A CLIP transformer based model.
|
/// A CLIP transformer based model.
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct ClipTextTransformer {
|
pub struct ClipTextTransformer {
|
||||||
embeddings: ClipTextEmbeddings,
|
embeddings: ClipTextEmbeddings,
|
||||||
encoder: ClipEncoder,
|
encoder: ClipEncoder,
|
||||||
@ -292,7 +277,6 @@ impl ClipTextTransformer {
|
|||||||
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
|
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
|
||||||
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
|
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
|
||||||
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
|
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
|
||||||
|
|
||||||
Ok(ClipTextTransformer {
|
Ok(ClipTextTransformer {
|
||||||
embeddings,
|
embeddings,
|
||||||
encoder,
|
encoder,
|
||||||
@ -325,7 +309,6 @@ impl ClipTextTransformer {
|
|||||||
pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
|
pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
|
||||||
let (bsz, seq_len) = input_ids.dims2()?;
|
let (bsz, seq_len) = input_ids.dims2()?;
|
||||||
let input_ids = self.embeddings.forward(input_ids)?;
|
let input_ids = self.embeddings.forward(input_ids)?;
|
||||||
|
|
||||||
let causal_attention_mask =
|
let causal_attention_mask =
|
||||||
Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
|
Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
|
||||||
let input_ids = self
|
let input_ids = self
|
||||||
@ -338,18 +321,13 @@ impl ClipTextTransformer {
|
|||||||
impl Module for ClipTextTransformer {
|
impl Module for ClipTextTransformer {
|
||||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
let output = self.forward_with_mask(input_ids, usize::MAX)?;
|
let output = self.forward_with_mask(input_ids, usize::MAX)?;
|
||||||
|
|
||||||
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
|
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
|
||||||
|
|
||||||
let mut indices: Vec<Tensor> = Vec::new();
|
let mut indices = Vec::new();
|
||||||
|
|
||||||
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
|
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
|
||||||
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
|
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
|
||||||
indices.push(index);
|
indices.push(index);
|
||||||
}
|
}
|
||||||
|
Tensor::cat(&indices, 0)
|
||||||
let pooled_output = Tensor::cat(&indices, 0)?;
|
|
||||||
|
|
||||||
Ok(pooled_output)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,6 @@ use candle::{IndexOp, Result, Shape, Tensor, D};
|
|||||||
use candle_nn as nn;
|
use candle_nn as nn;
|
||||||
use candle_nn::Module;
|
use candle_nn::Module;
|
||||||
use nn::Conv2dConfig;
|
use nn::Conv2dConfig;
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
text_model::{Activation, ClipEncoder},
|
text_model::{Activation, ClipEncoder},
|
||||||
@ -50,7 +49,7 @@ impl ClipVisionConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
|
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct ClipVisionEmbeddings {
|
struct ClipVisionEmbeddings {
|
||||||
patch_embedding: candle_nn::Conv2d,
|
patch_embedding: candle_nn::Conv2d,
|
||||||
position_ids: Tensor,
|
position_ids: Tensor,
|
||||||
@ -64,14 +63,11 @@ impl ClipVisionEmbeddings {
|
|||||||
let class_embedding = if vs.contains_tensor("class_embedding") {
|
let class_embedding = if vs.contains_tensor("class_embedding") {
|
||||||
vs.get(c.embed_dim, "class_embedding")?
|
vs.get(c.embed_dim, "class_embedding")?
|
||||||
} else {
|
} else {
|
||||||
warn!("class_embedding not found in the. Initializing a new one.");
|
Tensor::randn(0f32, 1f32, c.embed_dim, vs.device())?
|
||||||
Tensor::randn(0.0 as f32, 1.0 as f32, &[c.embed_dim], vs.device())?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let num_patches = (c.image_size / c.patch_size).pow(2);
|
let num_patches = (c.image_size / c.patch_size).pow(2);
|
||||||
|
|
||||||
let num_positions = num_patches + 1;
|
let num_positions = num_patches + 1;
|
||||||
|
|
||||||
let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
|
let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
|
||||||
|
|
||||||
let conv2dconfig = Conv2dConfig {
|
let conv2dconfig = Conv2dConfig {
|
||||||
@ -80,7 +76,6 @@ impl ClipVisionEmbeddings {
|
|||||||
};
|
};
|
||||||
let position_embedding =
|
let position_embedding =
|
||||||
candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?;
|
candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?;
|
||||||
|
|
||||||
let patch_embedding = candle_nn::conv2d_no_bias(
|
let patch_embedding = candle_nn::conv2d_no_bias(
|
||||||
c.num_channels,
|
c.num_channels,
|
||||||
c.embed_dim,
|
c.embed_dim,
|
||||||
@ -88,7 +83,6 @@ impl ClipVisionEmbeddings {
|
|||||||
conv2dconfig,
|
conv2dconfig,
|
||||||
vs.pp("patch_embedding"),
|
vs.pp("patch_embedding"),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
patch_embedding,
|
patch_embedding,
|
||||||
position_ids,
|
position_ids,
|
||||||
@ -101,31 +95,21 @@ impl ClipVisionEmbeddings {
|
|||||||
impl Module for ClipVisionEmbeddings {
|
impl Module for ClipVisionEmbeddings {
|
||||||
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||||
let batch_size = pixel_values.shape().dims();
|
let batch_size = pixel_values.shape().dims();
|
||||||
|
let patch_embeds = self
|
||||||
let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
|
.patch_embedding
|
||||||
|
.forward(pixel_values)?
|
||||||
let patch_embeds = patch_embeds.flatten_from(2)?;
|
.flatten_from(2)?
|
||||||
|
.transpose(1, 2)?;
|
||||||
let patch_embeds = patch_embeds.transpose(1, 2)?;
|
let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
|
||||||
|
let class_embeds = self.class_embedding.expand(shape)?;
|
||||||
let class_embedding = self.class_embedding.clone();
|
|
||||||
|
|
||||||
let shape = Shape::from(vec![batch_size[0], 1, class_embedding.dim(D::Minus1)?]);
|
|
||||||
|
|
||||||
let class_embeds = class_embedding.expand(shape)?;
|
|
||||||
|
|
||||||
let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
|
let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
|
||||||
|
|
||||||
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
|
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
|
||||||
|
embeddings.broadcast_add(&position_embedding)
|
||||||
let embeddings = embeddings.broadcast_add(&position_embedding)?;
|
|
||||||
|
|
||||||
Ok(embeddings)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
|
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct ClipVisionTransformer {
|
pub struct ClipVisionTransformer {
|
||||||
embeddings: ClipVisionEmbeddings,
|
embeddings: ClipVisionEmbeddings,
|
||||||
encoder: ClipEncoder,
|
encoder: ClipEncoder,
|
||||||
@ -136,13 +120,9 @@ pub struct ClipVisionTransformer {
|
|||||||
impl ClipVisionTransformer {
|
impl ClipVisionTransformer {
|
||||||
pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
|
pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
|
||||||
let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?;
|
let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?;
|
||||||
|
|
||||||
let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
|
let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
|
||||||
|
|
||||||
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?;
|
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?;
|
||||||
|
|
||||||
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?;
|
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
embeddings,
|
embeddings,
|
||||||
encoder,
|
encoder,
|
||||||
@ -154,18 +134,14 @@ impl ClipVisionTransformer {
|
|||||||
|
|
||||||
impl Module for ClipVisionTransformer {
|
impl Module for ClipVisionTransformer {
|
||||||
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||||
let hidden_states = self.embeddings.forward(pixel_values)?;
|
let hidden_states = pixel_values
|
||||||
|
.apply(&self.embeddings)?
|
||||||
let hidden_states = self.pre_layer_norm.forward(&hidden_states)?;
|
.apply(&self.pre_layer_norm)?;
|
||||||
|
|
||||||
let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
|
let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
|
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
|
||||||
// pooled_output = encoder_outputs[:, 0, :]
|
// pooled_output = encoder_outputs[:, 0, :]
|
||||||
let pooled_output = encoder_outputs.i((.., 0, ..))?;
|
let pooled_output = encoder_outputs.i((.., 0, ..))?;
|
||||||
|
self.final_layer_norm.forward(&pooled_output)
|
||||||
let output = self.final_layer_norm.forward(&pooled_output)?;
|
|
||||||
|
|
||||||
Ok(output)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ pub mod bigcode;
|
|||||||
pub mod blip;
|
pub mod blip;
|
||||||
pub mod blip_text;
|
pub mod blip_text;
|
||||||
pub mod chatglm;
|
pub mod chatglm;
|
||||||
|
pub mod clip;
|
||||||
pub mod convmixer;
|
pub mod convmixer;
|
||||||
pub mod convnext;
|
pub mod convnext;
|
||||||
pub mod dinov2;
|
pub mod dinov2;
|
||||||
@ -12,7 +13,6 @@ pub mod efficientvit;
|
|||||||
pub mod encodec;
|
pub mod encodec;
|
||||||
pub mod falcon;
|
pub mod falcon;
|
||||||
pub mod gemma;
|
pub mod gemma;
|
||||||
pub mod clip;
|
|
||||||
pub mod jina_bert;
|
pub mod jina_bert;
|
||||||
pub mod llama;
|
pub mod llama;
|
||||||
pub mod llama2_c;
|
pub mod llama2_c;
|
||||||
|
Reference in New Issue
Block a user