Fix clippy lints + minor cleanups. (#1957)

* Fix clippy lints + minor cleanups.

* fmt.

* Derive clone.
Laurent Mazare
2024-03-28 14:17:46 +01:00
committed by GitHub
parent b0340d72ec
commit cdc8b57b5c
4 changed files with 41 additions and 100 deletions

candle-transformers/src/models/clip/mod.rs

@@ -10,13 +10,11 @@ use self::{
     vision_model::ClipVisionTransformer,
 };
 use candle::{Result, Tensor, D};
-use candle_nn::Module;
-use tracing::warn;
 
 pub mod text_model;
 pub mod vision_model;
 
+#[derive(Clone, Debug)]
 pub struct ClipModel {
     text_model: ClipTextTransformer,
     vision_model: ClipVisionTransformer,
@@ -25,6 +23,7 @@ pub struct ClipModel {
     logit_scale: Tensor,
 }
 
+#[derive(Clone, Debug)]
 pub enum EncoderConfig {
     Text(text_model::ClipTextConfig),
     Vision(vision_model::ClipVisionConfig),
@@ -67,6 +66,7 @@ impl EncoderConfig {
     }
 }
 
+#[derive(Clone, Debug)]
 pub struct ClipConfig {
     pub text_config: text_model::ClipTextConfig,
     pub vision_config: vision_model::ClipVisionConfig,
@@ -111,7 +111,6 @@ impl ClipModel {
         let logit_scale = if vs.contains_tensor("logit_scale") {
             vs.get(&[], "logit_scale")?
         } else {
-            warn!("Creating logit_scale tensor, results may vary.");
             Tensor::new(&[c.logit_scale_init_value], vs.device())?
         };
@@ -125,38 +124,26 @@ impl ClipModel {
     }
 
     pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
-        let text_outputs = self.text_model.forward(input_ids)?;
-
-        let text_features = self.text_projection.forward(&text_outputs)?;
-
-        Ok(text_features)
+        input_ids
+            .apply(&self.text_model)?
+            .apply(&self.text_projection)
     }
 
     pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
-        let image_features = self.vision_model.forward(pixel_values)?;
-
-        let image_features = self.visual_projection.forward(&image_features)?;
-
-        Ok(image_features)
+        pixel_values
+            .apply(&self.vision_model)?
+            .apply(&self.visual_projection)
     }
 
     pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
         let image_features = self.get_image_features(pixel_values)?;
         let text_features = self.get_text_features(input_ids)?;
-
         let image_features_normalized = div_l2_norm(&image_features)?;
         let text_features_normalized = div_l2_norm(&text_features)?;
-
         let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
-
-        let logit_scale = &self.logit_scale.exp()?;
-
+        let logit_scale = self.logit_scale.exp()?;
         let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
-
         let logits_per_image = logits_per_text.t()?;
-
         Ok((logits_per_text, logits_per_image))
     }
 }
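Note on the change above: the rewritten get_text_features and get_image_features lean on Tensor::apply, which runs any Module on a tensor, so the intermediate let bindings and explicit forward calls collapse into a single chain whose final Result is returned directly. The following standalone sketch shows the same idiom outside this crate; it assumes the candle (candle-core) and candle_nn dependencies used by these files, and the Projector struct with its dimensions is made up purely for illustration.

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{linear, Linear, VarBuilder, VarMap};

    // Hypothetical two-layer projector, only here to show the chaining style.
    struct Projector {
        fc1: Linear,
        fc2: Linear,
    }

    impl Projector {
        fn new(vb: VarBuilder) -> Result<Self> {
            let fc1 = linear(16, 32, vb.pp("fc1"))?;
            let fc2 = linear(32, 8, vb.pp("fc2"))?;
            Ok(Self { fc1, fc2 })
        }

        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            // Same shape as the cleaned-up getters: no intermediate lets,
            // the final `apply` result is returned directly.
            xs.apply(&self.fc1)?.apply(&self.fc2)
        }
    }

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let proj = Projector::new(vb)?;
        let xs = Tensor::zeros((2, 16), DType::F32, &dev)?;
        println!("{:?}", proj.forward(&xs)?.shape()); // [2, 8]
        Ok(())
    }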

candle-transformers/src/models/clip/text_model.rs

@@ -59,7 +59,7 @@ impl ClipTextConfig {
 
 // ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model.
 // TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipTextEmbeddings {
     token_embedding: candle_nn::Embedding,
     position_embedding: candle_nn::Embedding,
@@ -70,16 +70,13 @@ impl ClipTextEmbeddings {
     fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
         let token_embedding =
             candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
-
         let position_embedding: nn::Embedding = candle_nn::embedding(
             c.max_position_embeddings,
             c.embed_dim,
             vs.pp("position_embedding"),
         )?;
-
         let position_ids =
             Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
-
         Ok(ClipTextEmbeddings {
             token_embedding,
             position_embedding,
@@ -91,20 +88,14 @@ impl ClipTextEmbeddings {
 impl Module for ClipTextEmbeddings {
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let seq_length = input_ids.dim(D::Minus1)?;
-
-        let inputs_embeds = &self.token_embedding.forward(input_ids)?;
-
-        let postion_ids = &self.position_ids.narrow(1, 0, seq_length)?;
-        let position_embedding = &self.position_embedding.forward(&postion_ids)?;
-        let inputs_embeds = inputs_embeds.broadcast_add(&position_embedding)?;
-        Ok(inputs_embeds)
+        let inputs_embeds = self.token_embedding.forward(input_ids)?;
+        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
+        let position_embedding = self.position_embedding.forward(&position_ids)?;
+        inputs_embeds.broadcast_add(&position_embedding)
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipAttention {
     k_proj: candle_nn::Linear,
     v_proj: candle_nn::Linear,
@@ -166,15 +157,10 @@ impl ClipAttention {
         let src_len = key_states.dim(1)?;
 
         let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
-            let attn_reshape =
-                attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?;
-            let attn_weights = attn_reshape.broadcast_add(causal_attention_mask)?;
-            let attn_weights =
-                attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
             attn_weights
+                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+                .broadcast_add(causal_attention_mask)?
+                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
         } else {
             attn_weights
         };
@@ -190,7 +176,7 @@ impl ClipAttention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipMlp {
     fc1: candle_nn::Linear,
     fc2: candle_nn::Linear,
@@ -217,7 +203,7 @@ impl ClipMlp {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipEncoderLayer {
     self_attn: ClipAttention,
     layer_norm1: candle_nn::LayerNorm,
@@ -253,7 +239,7 @@ impl ClipEncoderLayer {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct ClipEncoder {
     layers: Vec<ClipEncoderLayer>,
 }
@@ -271,7 +257,6 @@ impl ClipEncoder {
     pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
         let mut xs = xs.clone();
-
         for layer in self.layers.iter() {
             xs = layer.forward(&xs, causal_attention_mask)?;
         }
@@ -280,7 +265,7 @@ impl ClipEncoder {
 }
 
 /// A CLIP transformer based model.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct ClipTextTransformer {
     embeddings: ClipTextEmbeddings,
     encoder: ClipEncoder,
@@ -292,7 +277,6 @@ impl ClipTextTransformer {
         let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
         let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
         let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
-
         Ok(ClipTextTransformer {
             embeddings,
             encoder,
@@ -325,7 +309,6 @@ impl ClipTextTransformer {
     pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
         let (bsz, seq_len) = input_ids.dims2()?;
         let input_ids = self.embeddings.forward(input_ids)?;
-
         let causal_attention_mask =
             Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
         let input_ids = self
@@ -338,18 +321,13 @@ impl ClipTextTransformer {
 impl Module for ClipTextTransformer {
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let output = self.forward_with_mask(input_ids, usize::MAX)?;
-
         let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
-
-        let mut indices: Vec<Tensor> = Vec::new();
-
+        let mut indices = Vec::new();
         for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
             let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
             indices.push(index);
         }
-
-        let pooled_output = Tensor::cat(&indices, 0)?;
-        Ok(pooled_output)
+        Tensor::cat(&indices, 0)
     }
 }
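Context for the tidied-up Module::forward impl above: CLIP pools the text sequence at the EOS token, and since the EOS token has the largest id in CLIP's vocabulary, argmax over the raw token ids locates it in each sequence; the hidden state at that position becomes the pooled text embedding. The sketch below restates just that pooling step as a standalone function; the function name, the toy token ids, and the all-zero hidden states are illustrative assumptions, not part of the diff.

    use candle::{DType, Device, IndexOp, Result, Tensor, D};

    /// For every sequence in the batch, pick the hidden state at the position
    /// of the largest token id (the EOS token in CLIP's vocabulary).
    fn pool_at_eos(token_ids: &Tensor, hidden: &Tensor) -> Result<Tensor> {
        let eos_positions = token_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
        let mut rows = Vec::new();
        for (batch_idx, &pos) in eos_positions.to_vec1::<i64>()?.iter().enumerate() {
            rows.push(hidden.i((batch_idx, pos as usize))?.unsqueeze(0)?);
        }
        Tensor::cat(&rows, 0) // (batch, hidden_dim)
    }

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Two toy sequences of length 4; the largest id stands in for EOS.
        let token_ids = Tensor::new(&[[5u32, 9, 49407, 0], [7u32, 49407, 0, 0]], &dev)?;
        let hidden = Tensor::zeros((2, 4, 8), DType::F32, &dev)?;
        let pooled = pool_at_eos(&token_ids, &hidden)?;
        println!("{:?}", pooled.shape()); // [2, 8]
        Ok(())
    }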

candle-transformers/src/models/clip/vision_model.rs

@@ -10,7 +10,6 @@ use candle::{IndexOp, Result, Shape, Tensor, D};
 use candle_nn as nn;
 use candle_nn::Module;
 use nn::Conv2dConfig;
-use tracing::warn;
 
 use super::{
     text_model::{Activation, ClipEncoder},
@@ -50,7 +49,7 @@ impl ClipVisionConfig {
 }
 
 // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipVisionEmbeddings {
     patch_embedding: candle_nn::Conv2d,
     position_ids: Tensor,
@@ -64,14 +63,11 @@ impl ClipVisionEmbeddings {
         let class_embedding = if vs.contains_tensor("class_embedding") {
             vs.get(c.embed_dim, "class_embedding")?
         } else {
-            warn!("class_embedding not found in the. Initializing a new one.");
-            Tensor::randn(0.0 as f32, 1.0 as f32, &[c.embed_dim], vs.device())?
+            Tensor::randn(0f32, 1f32, c.embed_dim, vs.device())?
         };
-
         let num_patches = (c.image_size / c.patch_size).pow(2);
         let num_positions = num_patches + 1;
-
         let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
-
         let conv2dconfig = Conv2dConfig {
@@ -80,7 +76,6 @@ impl ClipVisionEmbeddings {
         };
         let position_embedding =
             candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?;
-
         let patch_embedding = candle_nn::conv2d_no_bias(
             c.num_channels,
             c.embed_dim,
@@ -88,7 +83,6 @@ impl ClipVisionEmbeddings {
             conv2dconfig,
             vs.pp("patch_embedding"),
         )?;
-
         Ok(Self {
             patch_embedding,
             position_ids,
@@ -101,31 +95,21 @@ impl ClipVisionEmbeddings {
 impl Module for ClipVisionEmbeddings {
     fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
         let batch_size = pixel_values.shape().dims();
-
-        let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
-
-        let patch_embeds = patch_embeds.flatten_from(2)?;
-
-        let patch_embeds = patch_embeds.transpose(1, 2)?;
-
-        let class_embedding = self.class_embedding.clone();
-        let shape = Shape::from(vec![batch_size[0], 1, class_embedding.dim(D::Minus1)?]);
-        let class_embeds = class_embedding.expand(shape)?;
-
+        let patch_embeds = self
+            .patch_embedding
+            .forward(pixel_values)?
+            .flatten_from(2)?
+            .transpose(1, 2)?;
+        let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
+        let class_embeds = self.class_embedding.expand(shape)?;
         let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
         let position_embedding = self.position_embedding.forward(&self.position_ids)?;
-
-        let embeddings = embeddings.broadcast_add(&position_embedding)?;
-        Ok(embeddings)
+        embeddings.broadcast_add(&position_embedding)
     }
 }
 
 // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct ClipVisionTransformer {
     embeddings: ClipVisionEmbeddings,
     encoder: ClipEncoder,
@@ -136,13 +120,9 @@ pub struct ClipVisionTransformer {
 impl ClipVisionTransformer {
     pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
         let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?;
-
         let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
-
         let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?;
-
         let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?;
-
         Ok(Self {
             embeddings,
             encoder,
@@ -154,18 +134,14 @@ impl ClipVisionTransformer {
 impl Module for ClipVisionTransformer {
     fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
-        let hidden_states = self.embeddings.forward(pixel_values)?;
-
-        let hidden_states = self.pre_layer_norm.forward(&hidden_states)?;
-
+        let hidden_states = pixel_values
+            .apply(&self.embeddings)?
+            .apply(&self.pre_layer_norm)?;
         let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
         // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
         // pooled_output = encoder_outputs[:, 0, :]
         let pooled_output = encoder_outputs.i((.., 0, ..))?;
-
-        let output = self.final_layer_norm.forward(&pooled_output)?;
-        Ok(output)
+        self.final_layer_norm.forward(&pooled_output)
     }
 }
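Two of the clippy-driven edits above are purely about API flexibility: Tensor::randn, like most candle constructors, accepts any Into<Shape>, so a bare usize replaces &[c.embed_dim] and the 0.0 as f32 casts become plain 0f32/1f32 literals, while Shape::from also takes a tuple, which removes the vec![...] allocation before expand. A small sketch under the same crate assumptions, with arbitrary dimension values:

    use candle::{Device, Result, Shape, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // `0f32`/`1f32` already pin the dtype, so no `as f32` casts are needed,
        // and a plain usize converts into a 1-D shape.
        let class_embedding = Tensor::randn(0f32, 1f32, 768usize, &dev)?; // 768 is arbitrary
        // A tuple converts into a Shape as well, so no `vec![...]` is required.
        let shape = Shape::from((1usize, 1, class_embedding.dim(0)?));
        let class_embeds = class_embedding.expand(shape)?;
        println!("{:?}", class_embeds.shape()); // [1, 1, 768]
        Ok(())
    }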

candle-transformers/src/models/mod.rs

@@ -3,6 +3,7 @@ pub mod bigcode;
 pub mod blip;
 pub mod blip_text;
 pub mod chatglm;
+pub mod clip;
 pub mod convmixer;
 pub mod convnext;
 pub mod dinov2;
@@ -12,7 +13,6 @@ pub mod efficientvit;
 pub mod encodec;
 pub mod falcon;
 pub mod gemma;
-pub mod clip;
 pub mod jina_bert;
 pub mod llama;
 pub mod llama2_c;