Fix clippy lints + minor cleanups. (#1957)

* Fix clippy lints + minor cleanups.

* fmt.

* Derive clone.
This commit is contained in:
Laurent Mazare
2024-03-28 14:17:46 +01:00
committed by GitHub
parent b0340d72ec
commit cdc8b57b5c
4 changed files with 41 additions and 100 deletions

View File

@ -10,13 +10,11 @@ use self::{
vision_model::ClipVisionTransformer,
};
use candle::{Result, Tensor, D};
use candle_nn::Module;
use tracing::warn;
pub mod text_model;
pub mod vision_model;
#[derive(Clone, Debug)]
pub struct ClipModel {
text_model: ClipTextTransformer,
vision_model: ClipVisionTransformer,
@ -25,6 +23,7 @@ pub struct ClipModel {
logit_scale: Tensor,
}
#[derive(Clone, Debug)]
pub enum EncoderConfig {
Text(text_model::ClipTextConfig),
Vision(vision_model::ClipVisionConfig),
@ -67,6 +66,7 @@ impl EncoderConfig {
}
}
#[derive(Clone, Debug)]
pub struct ClipConfig {
pub text_config: text_model::ClipTextConfig,
pub vision_config: vision_model::ClipVisionConfig,
@ -111,7 +111,6 @@ impl ClipModel {
let logit_scale = if vs.contains_tensor("logit_scale") {
vs.get(&[], "logit_scale")?
} else {
warn!("Creating logit_scale tensor, results may vary.");
Tensor::new(&[c.logit_scale_init_value], vs.device())?
};
@ -125,38 +124,26 @@ impl ClipModel {
}
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
let text_outputs = self.text_model.forward(input_ids)?;
let text_features = self.text_projection.forward(&text_outputs)?;
Ok(text_features)
input_ids
.apply(&self.text_model)?
.apply(&self.text_projection)
}
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
let image_features = self.vision_model.forward(pixel_values)?;
let image_features = self.visual_projection.forward(&image_features)?;
Ok(image_features)
pixel_values
.apply(&self.vision_model)?
.apply(&self.visual_projection)
}
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
let image_features = self.get_image_features(pixel_values)?;
let text_features = self.get_text_features(input_ids)?;
let image_features_normalized = div_l2_norm(&image_features)?;
let text_features_normalized = div_l2_norm(&text_features)?;
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
let logit_scale = &self.logit_scale.exp()?;
let logit_scale = self.logit_scale.exp()?;
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
let logits_per_image = logits_per_text.t()?;
Ok((logits_per_text, logits_per_image))
}
}