Mirror of https://github.com/huggingface/candle.git, synced 2025-06-22 12:28:06 +00:00
Fix clippy lints + minor cleanups. (#1957)
* Fix clippy lints + minor cleanups.
* fmt.
* Derive clone.
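The clippy fix in this commit replaces a binding of a borrowed temporary (`&self.logit_scale.exp()?`) with a plain value binding. A minimal standalone illustration of that kind of needless-borrow pattern (our own toy example on floats, not the actual candle code, which operates on tensors):

    fn main() {
        let x = 2.0f32;
        // Before: `&x.exp()` borrows a fresh temporary, and every later use
        // has to go back through that reference.
        let y = &x.exp();
        // After: bind the value directly and borrow only where needed.
        let z = x.exp();
        assert_eq!(*y, z);
    }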
@@ -10,13 +10,11 @@ use self::{
     vision_model::ClipVisionTransformer,
 };
 use candle::{Result, Tensor, D};
 use candle_nn::Module;
-
-use tracing::warn;
-
 pub mod text_model;
 pub mod vision_model;
 
+#[derive(Clone, Debug)]
 pub struct ClipModel {
     text_model: ClipTextTransformer,
     vision_model: ClipVisionTransformer,
@@ -25,6 +23,7 @@ pub struct ClipModel {
     logit_scale: Tensor,
 }
 
+#[derive(Clone, Debug)]
 pub enum EncoderConfig {
     Text(text_model::ClipTextConfig),
     Vision(vision_model::ClipVisionConfig),
@@ -67,6 +66,7 @@ impl EncoderConfig {
     }
 }
 
+#[derive(Clone, Debug)]
 pub struct ClipConfig {
     pub text_config: text_model::ClipTextConfig,
     pub vision_config: vision_model::ClipVisionConfig,
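These `#[derive(Clone, Debug)]` additions match the "Derive clone." bullet in the commit message. A minimal sketch of what the derive buys callers, using a hypothetical config struct (not candle's real `ClipConfig`):

    #[derive(Clone, Debug)]
    struct DemoConfig {
        image_size: usize,
        logit_scale_init_value: f32,
    }

    fn main() {
        let cfg = DemoConfig { image_size: 224, logit_scale_init_value: 2.6592 };
        // `clone` lets the config be handed to another thread while the
        // original stays usable; without the derive this would not compile.
        let cfg2 = cfg.clone();
        std::thread::spawn(move || println!("worker sees {cfg2:?}"))
            .join()
            .unwrap();
        println!("main still has {cfg:?}");
    }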
@@ -111,7 +111,6 @@ impl ClipModel {
         let logit_scale = if vs.contains_tensor("logit_scale") {
             vs.get(&[], "logit_scale")?
         } else {
-            warn!("Creating logit_scale tensor, results may vary.");
             Tensor::new(&[c.logit_scale_init_value], vs.device())?
         };
 
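The hunk above keeps the checkpoint-or-config fallback for `logit_scale`. A sketch of the same load-or-init pattern against a synthetic `VarBuilder` (the zeros builder stands in for real checkpoint weights, and the free function is our own framing, not candle's API):

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::VarBuilder;

    fn logit_scale(vs: &VarBuilder, init_value: f32) -> Result<Tensor> {
        // Prefer the value stored in the checkpoint; fall back to the
        // configured initial value otherwise.
        if vs.contains_tensor("logit_scale") {
            let scalar: &[usize] = &[];
            vs.get(scalar, "logit_scale")
        } else {
            Tensor::new(&[init_value], vs.device())
        }
    }

    fn main() -> Result<()> {
        let vs = VarBuilder::zeros(DType::F32, &Device::Cpu);
        println!("{}", logit_scale(&vs, 2.6592)?);
        Ok(())
    }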
@@ -125,38 +124,26 @@ impl ClipModel {
     }
 
     pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
-        let text_outputs = self.text_model.forward(input_ids)?;
-
-        let text_features = self.text_projection.forward(&text_outputs)?;
-
-        Ok(text_features)
+        input_ids
+            .apply(&self.text_model)?
+            .apply(&self.text_projection)
     }
 
     pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
-        let image_features = self.vision_model.forward(pixel_values)?;
-
-        let image_features = self.visual_projection.forward(&image_features)?;
-
-        Ok(image_features)
+        pixel_values
+            .apply(&self.vision_model)?
+            .apply(&self.visual_projection)
     }
 
     pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
         let image_features = self.get_image_features(pixel_values)?;
-
         let text_features = self.get_text_features(input_ids)?;
-
         let image_features_normalized = div_l2_norm(&image_features)?;
-
         let text_features_normalized = div_l2_norm(&text_features)?;
-
         let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
-
-        let logit_scale = &self.logit_scale.exp()?;
-
+        let logit_scale = self.logit_scale.exp()?;
         let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
-
         let logits_per_image = logits_per_text.t()?;
-
         Ok((logits_per_text, logits_per_image))
     }
 }
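The projection helpers are rewritten from explicit `forward` calls into `Tensor::apply` chains; `x.apply(&m)` is the flipped form of `m.forward(&x)`, so the two projection steps read as a left-to-right pipeline. A small sketch of that equivalence, assuming only a `candle_nn::Linear` (not the real CLIP transformers):

    use candle::{Device, Result, Tensor};
    use candle_nn::Module;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let w = Tensor::randn(0f32, 1.0, (8, 16), &dev)?;
        let proj = candle_nn::Linear::new(w, None);
        let x = Tensor::randn(0f32, 1.0, (4, 16), &dev)?;
        // `x.apply(&proj)` and `proj.forward(&x)` produce the same tensor;
        // `apply` just reverses the call direction so modules chain naturally.
        let a = proj.forward(&x)?;
        let b = x.apply(&proj)?;
        assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
        Ok(())
    }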
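For reference, the slimmed-down `forward` computes scaled cosine similarities between every text/image pair. A self-contained sketch of those steps on random features, with our own stand-in for the file's `div_l2_norm` helper (its exact definition is not part of this diff):

    use candle::{Device, Result, Tensor, D};

    // Assumed equivalent of `div_l2_norm`: scale each row along the last
    // dimension to unit L2 norm.
    fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
        let l2 = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
        v.broadcast_div(&l2)
    }

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Stand-ins for the projected features: (batch, embed_dim).
        let text = div_l2_norm(&Tensor::randn(0f32, 1.0, (3, 8), &dev)?)?;
        let image = div_l2_norm(&Tensor::randn(0f32, 1.0, (5, 8), &dev)?)?;
        // Cosine similarities scaled by exp(logit_scale), as in `forward`;
        // 2.6592 mirrors CLIP's usual ln(1/0.07) initial value.
        let logit_scale = Tensor::new(&[2.6592f32], &dev)?.exp()?;
        let logits_per_text = text.matmul(&image.t()?)?.broadcast_mul(&logit_scale)?;
        let logits_per_image = logits_per_text.t()?;
        println!("{:?} {:?}", logits_per_text.dims(), logits_per_image.dims());
        Ok(())
    }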