Fix clippy lints + minor cleanups. (#1957)

* Fix clippy lints + minor cleanups.

* fmt.

* Derive clone.
Laurent Mazare
2024-03-28 14:17:46 +01:00
committed by GitHub
parent b0340d72ec
commit cdc8b57b5c
4 changed files with 41 additions and 100 deletions
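
The cleanups below follow two recurring patterns: existing #[derive(Debug)] attributes gain Clone so the model structs can be duplicated, and let bindings that are immediately returned (or wrapped in Ok(...)) are collapsed into a single tail expression, as clippy's let_and_return lint suggests. A minimal, self-contained sketch of both patterns follows; the Embedding struct and lookup method are illustrative stand-ins, not candle types.

// Illustrative stand-ins only (not candle types), sketching the two cleanups in this commit:
// deriving Clone next to Debug, and returning the final Result expression directly instead
// of binding it and wrapping it in Ok(...).

#[derive(Debug)]
enum Error {
    OutOfRange,
}

type Result<T> = std::result::Result<T, Error>;

#[derive(Clone, Debug)] // previously just #[derive(Debug)]
struct Embedding {
    weights: Vec<f32>,
}

impl Embedding {
    fn lookup(&self, index: usize) -> Result<f32> {
        self.weights.get(index).copied().ok_or(Error::OutOfRange)
    }

    // Before the cleanup this body read:
    //     let value = self.lookup(index)?;
    //     Ok(value)
    // The tail expression already has the right type, so it is returned as-is.
    fn forward(&self, index: usize) -> Result<f32> {
        self.lookup(index)
    }
}

fn main() -> Result<()> {
    let emb = Embedding { weights: vec![0.5, 1.5] };
    let copy = emb.clone(); // made possible by the added Clone derive
    println!("{:?} -> {}", copy, copy.forward(1)?);
    Ok(())
}

Deriving Clone on every sub-struct is also what lets the top-level ClipTextTransformer derive it: a derived Clone requires every field to implement Clone.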


@@ -59,7 +59,7 @@ impl ClipTextConfig {
// ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model.
// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
-#[derive(Debug)]
+#[derive(Clone, Debug)]
struct ClipTextEmbeddings {
token_embedding: candle_nn::Embedding,
position_embedding: candle_nn::Embedding,
@@ -70,16 +70,13 @@ impl ClipTextEmbeddings {
fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
let token_embedding =
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
let position_embedding: nn::Embedding = candle_nn::embedding(
c.max_position_embeddings,
c.embed_dim,
vs.pp("position_embedding"),
)?;
let position_ids =
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
Ok(ClipTextEmbeddings {
token_embedding,
position_embedding,
@@ -91,20 +88,14 @@ impl ClipTextEmbeddings {
impl Module for ClipTextEmbeddings {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let seq_length = input_ids.dim(D::Minus1)?;
-        let inputs_embeds = &self.token_embedding.forward(input_ids)?;
-        let postion_ids = &self.position_ids.narrow(1, 0, seq_length)?;
-        let position_embedding = &self.position_embedding.forward(&postion_ids)?;
-        let inputs_embeds = inputs_embeds.broadcast_add(&position_embedding)?;
-        Ok(inputs_embeds)
+        let inputs_embeds = self.token_embedding.forward(input_ids)?;
+        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
+        let position_embedding = self.position_embedding.forward(&position_ids)?;
+        inputs_embeds.broadcast_add(&position_embedding)
}
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
struct ClipAttention {
k_proj: candle_nn::Linear,
v_proj: candle_nn::Linear,
@@ -166,15 +157,10 @@ impl ClipAttention {
let src_len = key_states.dim(1)?;
let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
-            let attn_reshape =
-                attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?;
-            let attn_weights = attn_reshape.broadcast_add(causal_attention_mask)?;
-            let attn_weights =
-                attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
attn_weights
+                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+                .broadcast_add(causal_attention_mask)?
+                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
} else {
attn_weights
};
@@ -190,7 +176,7 @@ impl ClipAttention {
}
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
struct ClipMlp {
fc1: candle_nn::Linear,
fc2: candle_nn::Linear,
@@ -217,7 +203,7 @@ impl ClipMlp {
}
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
struct ClipEncoderLayer {
self_attn: ClipAttention,
layer_norm1: candle_nn::LayerNorm,
@@ -253,7 +239,7 @@ impl ClipEncoderLayer {
}
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct ClipEncoder {
layers: Vec<ClipEncoderLayer>,
}
@@ -271,7 +257,6 @@ impl ClipEncoder {
pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
}
@@ -280,7 +265,7 @@ impl ClipEncoder {
}
/// A CLIP transformer based model.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct ClipTextTransformer {
embeddings: ClipTextEmbeddings,
encoder: ClipEncoder,
@@ -292,7 +277,6 @@ impl ClipTextTransformer {
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
Ok(ClipTextTransformer {
embeddings,
encoder,
@@ -325,7 +309,6 @@ impl ClipTextTransformer {
pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
let (bsz, seq_len) = input_ids.dims2()?;
let input_ids = self.embeddings.forward(input_ids)?;
let causal_attention_mask =
Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
let input_ids = self
@@ -338,18 +321,13 @@ impl ClipTextTransformer {
impl Module for ClipTextTransformer {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let output = self.forward_with_mask(input_ids, usize::MAX)?;
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
-        let mut indices: Vec<Tensor> = Vec::new();
+        let mut indices = Vec::new();
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
indices.push(index);
}
-        let pooled_output = Tensor::cat(&indices, 0)?;
-        Ok(pooled_output)
+        Tensor::cat(&indices, 0)
}
}
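
For context on the final hunk: the forward pass pools the encoder output by taking, for each batch element, the hidden state at the position of the largest token id (the end-of-text token in CLIP's vocabulary) and concatenating those rows. Below is a plain-Rust sketch of that selection using nested Vecs instead of candle tensors; the function and variable names are illustrative, not part of the crate.

// Illustrative only: mirrors the pooled-output selection from the final hunk
// using plain vectors instead of candle tensors.
fn pool_by_max_token_id(
    input_ids: &[Vec<u32>],   // (batch, seq_len) token ids
    hidden: &[Vec<Vec<f32>>], // (batch, seq_len, dim) encoder output
) -> Vec<Vec<f32>> {
    let mut pooled = Vec::new();
    for (batch_idx, ids) in input_ids.iter().enumerate() {
        // Argmax over the sequence dimension, like input_ids.argmax(D::Minus1) above.
        let mut seq_idx = 0;
        for (i, &id) in ids.iter().enumerate() {
            if id > ids[seq_idx] {
                seq_idx = i;
            }
        }
        // Pick the hidden state at (batch_idx, seq_idx), like output.i((batch_idx, seq_idx)).
        pooled.push(hidden[batch_idx][seq_idx].clone());
    }
    // Stacking the selected rows corresponds to Tensor::cat(&indices, 0).
    pooled
}

fn main() {
    let ids = vec![vec![5u32, 7, 49407, 0]]; // 49407 is the usual CLIP end-of-text id
    let hidden = vec![vec![vec![0.1f32; 4]; 4]];
    let pooled = pool_by_max_token_id(&ids, &hidden);
    println!("pooled: {} x {}", pooled.len(), pooled[0].len());
}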