From cdc8b57b5cf28ad92642b076d67e610bdb958b2d Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 28 Mar 2024 14:17:46 +0100 Subject: [PATCH] Fix clippy lints + minor cleanups. (#1957) * Fix clippy lints + minor cleanups. * fmt. * Derive clone. --- candle-transformers/src/models/clip/mod.rs | 33 ++++-------- .../src/models/clip/text_model.rs | 52 ++++++------------ .../src/models/clip/vision_model.rs | 54 ++++++------------- candle-transformers/src/models/mod.rs | 2 +- 4 files changed, 41 insertions(+), 100 deletions(-) diff --git a/candle-transformers/src/models/clip/mod.rs b/candle-transformers/src/models/clip/mod.rs index 02df782b..9613fdab 100644 --- a/candle-transformers/src/models/clip/mod.rs +++ b/candle-transformers/src/models/clip/mod.rs @@ -10,13 +10,11 @@ use self::{ vision_model::ClipVisionTransformer, }; use candle::{Result, Tensor, D}; -use candle_nn::Module; - -use tracing::warn; pub mod text_model; pub mod vision_model; +#[derive(Clone, Debug)] pub struct ClipModel { text_model: ClipTextTransformer, vision_model: ClipVisionTransformer, @@ -25,6 +23,7 @@ pub struct ClipModel { logit_scale: Tensor, } +#[derive(Clone, Debug)] pub enum EncoderConfig { Text(text_model::ClipTextConfig), Vision(vision_model::ClipVisionConfig), @@ -67,6 +66,7 @@ impl EncoderConfig { } } +#[derive(Clone, Debug)] pub struct ClipConfig { pub text_config: text_model::ClipTextConfig, pub vision_config: vision_model::ClipVisionConfig, @@ -111,7 +111,6 @@ impl ClipModel { let logit_scale = if vs.contains_tensor("logit_scale") { vs.get(&[], "logit_scale")? } else { - warn!("Creating logit_scale tensor, results may vary."); Tensor::new(&[c.logit_scale_init_value], vs.device())? }; @@ -125,38 +124,26 @@ impl ClipModel { } pub fn get_text_features(&self, input_ids: &Tensor) -> Result { - let text_outputs = self.text_model.forward(input_ids)?; - - let text_features = self.text_projection.forward(&text_outputs)?; - - Ok(text_features) + input_ids + .apply(&self.text_model)? + .apply(&self.text_projection) } pub fn get_image_features(&self, pixel_values: &Tensor) -> Result { - let image_features = self.vision_model.forward(pixel_values)?; - - let image_features = self.visual_projection.forward(&image_features)?; - - Ok(image_features) + pixel_values + .apply(&self.vision_model)? + .apply(&self.visual_projection) } pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> { let image_features = self.get_image_features(pixel_values)?; - let text_features = self.get_text_features(input_ids)?; - let image_features_normalized = div_l2_norm(&image_features)?; - let text_features_normalized = div_l2_norm(&text_features)?; - let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?; - - let logit_scale = &self.logit_scale.exp()?; - + let logit_scale = self.logit_scale.exp()?; let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?; - let logits_per_image = logits_per_text.t()?; - Ok((logits_per_text, logits_per_image)) } } diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs index 852d3e24..d3ba26ff 100644 --- a/candle-transformers/src/models/clip/text_model.rs +++ b/candle-transformers/src/models/clip/text_model.rs @@ -59,7 +59,7 @@ impl ClipTextConfig { // ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model. 
// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142 -#[derive(Debug)] +#[derive(Clone, Debug)] struct ClipTextEmbeddings { token_embedding: candle_nn::Embedding, position_embedding: candle_nn::Embedding, @@ -70,16 +70,13 @@ impl ClipTextEmbeddings { fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { let token_embedding = candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?; - let position_embedding: nn::Embedding = candle_nn::embedding( c.max_position_embeddings, c.embed_dim, vs.pp("position_embedding"), )?; - let position_ids = Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?; - Ok(ClipTextEmbeddings { token_embedding, position_embedding, @@ -91,20 +88,14 @@ impl ClipTextEmbeddings { impl Module for ClipTextEmbeddings { fn forward(&self, input_ids: &Tensor) -> Result { let seq_length = input_ids.dim(D::Minus1)?; - - let inputs_embeds = &self.token_embedding.forward(input_ids)?; - - let postion_ids = &self.position_ids.narrow(1, 0, seq_length)?; - - let position_embedding = &self.position_embedding.forward(&postion_ids)?; - - let inputs_embeds = inputs_embeds.broadcast_add(&position_embedding)?; - - Ok(inputs_embeds) + let inputs_embeds = self.token_embedding.forward(input_ids)?; + let position_ids = self.position_ids.narrow(1, 0, seq_length)?; + let position_embedding = self.position_embedding.forward(&position_ids)?; + inputs_embeds.broadcast_add(&position_embedding) } } -#[derive(Debug)] +#[derive(Clone, Debug)] struct ClipAttention { k_proj: candle_nn::Linear, v_proj: candle_nn::Linear, @@ -166,15 +157,10 @@ impl ClipAttention { let src_len = key_states.dim(1)?; let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask { - let attn_reshape = - attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?; - - let attn_weights = attn_reshape.broadcast_add(causal_attention_mask)?; - - let attn_weights = - attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?; - attn_weights + .reshape((bsz, self.num_attention_heads, seq_len, src_len))? + .broadcast_add(causal_attention_mask)? + .reshape((bsz * self.num_attention_heads, seq_len, src_len))? } else { attn_weights }; @@ -190,7 +176,7 @@ impl ClipAttention { } } -#[derive(Debug)] +#[derive(Clone, Debug)] struct ClipMlp { fc1: candle_nn::Linear, fc2: candle_nn::Linear, @@ -217,7 +203,7 @@ impl ClipMlp { } } -#[derive(Debug)] +#[derive(Clone, Debug)] struct ClipEncoderLayer { self_attn: ClipAttention, layer_norm1: candle_nn::LayerNorm, @@ -253,7 +239,7 @@ impl ClipEncoderLayer { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ClipEncoder { layers: Vec, } @@ -271,7 +257,6 @@ impl ClipEncoder { pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { let mut xs = xs.clone(); - for layer in self.layers.iter() { xs = layer.forward(&xs, causal_attention_mask)?; } @@ -280,7 +265,7 @@ impl ClipEncoder { } /// A CLIP transformer based model. 
-#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ClipTextTransformer { embeddings: ClipTextEmbeddings, encoder: ClipEncoder, @@ -292,7 +277,6 @@ impl ClipTextTransformer { let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?; let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?; let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?; - Ok(ClipTextTransformer { embeddings, encoder, @@ -325,7 +309,6 @@ impl ClipTextTransformer { pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result { let (bsz, seq_len) = input_ids.dims2()?; let input_ids = self.embeddings.forward(input_ids)?; - let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?; let input_ids = self @@ -338,18 +321,13 @@ impl ClipTextTransformer { impl Module for ClipTextTransformer { fn forward(&self, input_ids: &Tensor) -> Result { let output = self.forward_with_mask(input_ids, usize::MAX)?; - let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?; - let mut indices: Vec = Vec::new(); - + let mut indices = Vec::new(); for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::()?.iter().enumerate() { let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?; indices.push(index); } - - let pooled_output = Tensor::cat(&indices, 0)?; - - Ok(pooled_output) + Tensor::cat(&indices, 0) } } diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs index af9af7ae..88992434 100644 --- a/candle-transformers/src/models/clip/vision_model.rs +++ b/candle-transformers/src/models/clip/vision_model.rs @@ -10,7 +10,6 @@ use candle::{IndexOp, Result, Shape, Tensor, D}; use candle_nn as nn; use candle_nn::Module; use nn::Conv2dConfig; -use tracing::warn; use super::{ text_model::{Activation, ClipEncoder}, @@ -50,7 +49,7 @@ impl ClipVisionConfig { } // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112 -#[derive(Debug)] +#[derive(Clone, Debug)] struct ClipVisionEmbeddings { patch_embedding: candle_nn::Conv2d, position_ids: Tensor, @@ -64,14 +63,11 @@ impl ClipVisionEmbeddings { let class_embedding = if vs.contains_tensor("class_embedding") { vs.get(c.embed_dim, "class_embedding")? } else { - warn!("class_embedding not found in the. Initializing a new one."); - Tensor::randn(0.0 as f32, 1.0 as f32, &[c.embed_dim], vs.device())? + Tensor::randn(0f32, 1f32, c.embed_dim, vs.device())? 
}; let num_patches = (c.image_size / c.patch_size).pow(2); - let num_positions = num_patches + 1; - let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?; let conv2dconfig = Conv2dConfig { @@ -80,7 +76,6 @@ impl ClipVisionEmbeddings { }; let position_embedding = candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?; - let patch_embedding = candle_nn::conv2d_no_bias( c.num_channels, c.embed_dim, @@ -88,7 +83,6 @@ impl ClipVisionEmbeddings { conv2dconfig, vs.pp("patch_embedding"), )?; - Ok(Self { patch_embedding, position_ids, @@ -101,31 +95,21 @@ impl ClipVisionEmbeddings { impl Module for ClipVisionEmbeddings { fn forward(&self, pixel_values: &Tensor) -> Result { let batch_size = pixel_values.shape().dims(); - - let patch_embeds = self.patch_embedding.forward(&pixel_values)?; - - let patch_embeds = patch_embeds.flatten_from(2)?; - - let patch_embeds = patch_embeds.transpose(1, 2)?; - - let class_embedding = self.class_embedding.clone(); - - let shape = Shape::from(vec![batch_size[0], 1, class_embedding.dim(D::Minus1)?]); - - let class_embeds = class_embedding.expand(shape)?; - + let patch_embeds = self + .patch_embedding + .forward(pixel_values)? + .flatten_from(2)? + .transpose(1, 2)?; + let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?)); + let class_embeds = self.class_embedding.expand(shape)?; let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?; - let position_embedding = self.position_embedding.forward(&self.position_ids)?; - - let embeddings = embeddings.broadcast_add(&position_embedding)?; - - Ok(embeddings) + embeddings.broadcast_add(&position_embedding) } } // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743 -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ClipVisionTransformer { embeddings: ClipVisionEmbeddings, encoder: ClipEncoder, @@ -136,13 +120,9 @@ pub struct ClipVisionTransformer { impl ClipVisionTransformer { pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result { let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?; - let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?; - let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?; - let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?; - Ok(Self { embeddings, encoder, @@ -154,18 +134,14 @@ impl ClipVisionTransformer { impl Module for ClipVisionTransformer { fn forward(&self, pixel_values: &Tensor) -> Result { - let hidden_states = self.embeddings.forward(pixel_values)?; - - let hidden_states = self.pre_layer_norm.forward(&hidden_states)?; + let hidden_states = pixel_values + .apply(&self.embeddings)? 
+ .apply(&self.pre_layer_norm)?; let encoder_outputs = self.encoder.forward(&hidden_states, None)?; - // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787 // pooled_output = encoder_outputs[:, 0, :] let pooled_output = encoder_outputs.i((.., 0, ..))?; - - let output = self.final_layer_norm.forward(&pooled_output)?; - - Ok(output) + self.final_layer_norm.forward(&pooled_output) } } diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 4267059c..6fbc1844 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -3,6 +3,7 @@ pub mod bigcode; pub mod blip; pub mod blip_text; pub mod chatglm; +pub mod clip; pub mod convmixer; pub mod convnext; pub mod dinov2; @@ -12,7 +13,6 @@ pub mod efficientvit; pub mod encodec; pub mod falcon; pub mod gemma; -pub mod clip; pub mod jina_bert; pub mod llama; pub mod llama2_c;
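The bulk of this cleanup replaces explicit `module.forward(&x)` calls and their intermediate `let` bindings with chained `Tensor::apply` calls, as in `get_text_features` and `get_image_features` above. Below is a minimal sketch of that style, assuming the same `candle`/`candle_nn` crates; the `ProjectionHead` type is hypothetical and not part of this patch.

use candle::{Result, Tensor};
use candle_nn::Module;

// Hypothetical two-layer head, written in the chained `Tensor::apply` style
// this patch moves the CLIP code towards.
struct ProjectionHead {
    fc1: candle_nn::Linear,
    fc2: candle_nn::Linear,
}

impl Module for ProjectionHead {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // `xs.apply(&m)?` dispatches to `m.forward(xs)?`, so this chain is
        // behaviorally identical to a version with one `let` per step.
        xs.apply(&self.fc1)?.relu()?.apply(&self.fc2)
    }
}

Because `apply` is an inherent method on `Tensor`, the rewritten call sites in `mod.rs` no longer need the `Module` trait in scope, which is why the `use candle_nn::Module;` import is dropped there.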