mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
CLIP model implementation with example (#1950)
* CLIP model implementation with example * CLIP Implementation fixes, batch images * CLIP model remove images from git * CLIP model remove unnecessary use of batch_indices
This commit is contained in:

committed by
GitHub

parent
b3484e7a5e
commit
b0340d72ec
46
candle-examples/examples/clip/README.md
Normal file
46
candle-examples/examples/clip/README.md
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
Contrastive Language-Image Pre-Training
|
||||||
|
|
||||||
|
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||||
|
pairs of images with related texts.
|
||||||
|
|
||||||
|
https://github.com/openai/CLIP
|
||||||
|
|
||||||
|
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
|
||||||
|
|
||||||
|
## Running on an example on cpu
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
|
||||||
|
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||||
|
|
||||||
|
INFO clip: Probability: 0.0000% Text: a cycling race
|
||||||
|
INFO clip: Probability: 0.0000% Text: a photo of two cats
|
||||||
|
INFO clip: Probability: 100.0000% Text: a robot holding a candle
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
INFO clip: Probability: 99.9999% Text: a cycling race
|
||||||
|
INFO clip: Probability: 0.0001% Text: a photo of two cats
|
||||||
|
INFO clip: Probability: 0.0000% Text: a robot holding a candle
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running on an example with metal feature (mac)
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
|
||||||
|
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||||
|
|
||||||
|
INFO clip: Probability: 0.0000% Text: a cycling race
|
||||||
|
INFO clip: Probability: 0.0000% Text: a photo of two cats
|
||||||
|
INFO clip: Probability: 100.0000% Text: a robot holding a candle
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
INFO clip: Probability: 99.9999% Text: a cycling race
|
||||||
|
INFO clip: Probability: 0.0001% Text: a photo of two cats
|
||||||
|
INFO clip: Probability: 0.0000% Text: a robot holding a candle
|
||||||
|
```
|
202
candle-examples/examples/clip/main.rs
Normal file
202
candle-examples/examples/clip/main.rs
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::Error as E;
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_nn::{ops::softmax, VarBuilder};
|
||||||
|
use candle_transformers::models::clip;
|
||||||
|
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, use_value_delimiter = true)]
|
||||||
|
images: Option<Vec<String>>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
#[arg(long, use_value_delimiter = true)]
|
||||||
|
sequences: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
|
||||||
|
let img = image::io::Reader::open(path)?.decode()?;
|
||||||
|
let (height, width) = (image_size, image_size);
|
||||||
|
let img = img.resize_to_fill(
|
||||||
|
width as u32,
|
||||||
|
height as u32,
|
||||||
|
image::imageops::FilterType::Triangle,
|
||||||
|
);
|
||||||
|
|
||||||
|
let img = img.to_rgb8();
|
||||||
|
|
||||||
|
let img = img.into_raw();
|
||||||
|
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
|
||||||
|
.permute((2, 0, 1))?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.affine(2. / 255., -1.)?;
|
||||||
|
// .unsqueeze(0)?;
|
||||||
|
Ok(img)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_images<T: AsRef<std::path::Path>>(
|
||||||
|
paths: &Vec<T>,
|
||||||
|
image_size: usize,
|
||||||
|
) -> anyhow::Result<Tensor> {
|
||||||
|
let mut images = vec![];
|
||||||
|
|
||||||
|
for path in paths {
|
||||||
|
let tensor = load_image(path, image_size)?;
|
||||||
|
images.push(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
let images = Tensor::stack(&images, 0)?;
|
||||||
|
|
||||||
|
Ok(images)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
// std::env::set_var("RUST_BACKTRACE", "full");
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
let model_file = match args.model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
|
||||||
|
let api = api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"openai/clip-vit-base-patch32".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/15".to_string(),
|
||||||
|
));
|
||||||
|
|
||||||
|
api.get("model.safetensors")?
|
||||||
|
}
|
||||||
|
Some(model) => model.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let tokenizer = get_tokenizer(args.tokenizer)?;
|
||||||
|
|
||||||
|
let config = clip::ClipConfig::vit_base_patch32();
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let vec_imgs = match args.images {
|
||||||
|
Some(imgs) => imgs,
|
||||||
|
None => vec![
|
||||||
|
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
|
||||||
|
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
|
||||||
|
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
|
||||||
|
|
||||||
|
let vb =
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
|
||||||
|
|
||||||
|
let model = clip::ClipModel::new(vb, &config)?;
|
||||||
|
|
||||||
|
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
|
||||||
|
|
||||||
|
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
|
||||||
|
|
||||||
|
let softmax_image = softmax(&logits_per_image, 1)?;
|
||||||
|
|
||||||
|
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
|
||||||
|
|
||||||
|
info!("softmax_image_vec: {:?}", softmax_image_vec);
|
||||||
|
|
||||||
|
let probability_vec = softmax_image_vec
|
||||||
|
.iter()
|
||||||
|
.map(|v| v * 100.0)
|
||||||
|
.collect::<Vec<f32>>();
|
||||||
|
|
||||||
|
let probability_per_image = probability_vec.len() / vec_imgs.len();
|
||||||
|
|
||||||
|
for (i, img) in vec_imgs.iter().enumerate() {
|
||||||
|
let start = i * probability_per_image;
|
||||||
|
let end = start + probability_per_image;
|
||||||
|
let prob = &probability_vec[start..end];
|
||||||
|
info!("\n\nResults for image: {}\n", img);
|
||||||
|
|
||||||
|
for (i, p) in prob.iter().enumerate() {
|
||||||
|
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
|
||||||
|
let tokenizer = match tokenizer {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"openai/clip-vit-base-patch32".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/15".to_string(),
|
||||||
|
));
|
||||||
|
api.get("tokenizer.json")?
|
||||||
|
}
|
||||||
|
Some(file) => file.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Tokenizer::from_file(tokenizer).map_err(E::msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tokenize_sequences(
|
||||||
|
sequences: Option<Vec<String>>,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
device: &Device,
|
||||||
|
) -> anyhow::Result<(Tensor, Vec<String>)> {
|
||||||
|
let pad_id = *tokenizer
|
||||||
|
.get_vocab(true)
|
||||||
|
.get("<|endoftext|>")
|
||||||
|
.ok_or(E::msg("No pad token"))?;
|
||||||
|
|
||||||
|
let vec_seq = match sequences {
|
||||||
|
Some(seq) => seq,
|
||||||
|
None => vec![
|
||||||
|
"a cycling race".to_string(),
|
||||||
|
"a photo of two cats".to_string(),
|
||||||
|
"a robot holding a candle".to_string(),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut tokens = vec![];
|
||||||
|
|
||||||
|
for seq in vec_seq.clone() {
|
||||||
|
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
|
||||||
|
tokens.push(encoding.get_ids().to_vec());
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
|
||||||
|
|
||||||
|
// Pad the sequences to have the same length
|
||||||
|
for token_vec in tokens.iter_mut() {
|
||||||
|
let len_diff = max_len - token_vec.len();
|
||||||
|
if len_diff > 0 {
|
||||||
|
token_vec.extend(vec![pad_id; len_diff]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let input_ids = Tensor::new(tokens, device)?;
|
||||||
|
|
||||||
|
Ok((input_ids, vec_seq))
|
||||||
|
}
|
167
candle-transformers/src/models/clip/mod.rs
Normal file
167
candle-transformers/src/models/clip/mod.rs
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
//! Contrastive Language-Image Pre-Training
|
||||||
|
//!
|
||||||
|
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||||
|
//! pairs of images with related texts.
|
||||||
|
//!
|
||||||
|
//! https://github.com/openai/CLIP
|
||||||
|
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
|
||||||
|
use self::{
|
||||||
|
text_model::{Activation, ClipTextTransformer},
|
||||||
|
vision_model::ClipVisionTransformer,
|
||||||
|
};
|
||||||
|
use candle::{Result, Tensor, D};
|
||||||
|
use candle_nn::Module;
|
||||||
|
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
pub mod text_model;
|
||||||
|
pub mod vision_model;
|
||||||
|
|
||||||
|
pub struct ClipModel {
|
||||||
|
text_model: ClipTextTransformer,
|
||||||
|
vision_model: ClipVisionTransformer,
|
||||||
|
visual_projection: candle_nn::Linear,
|
||||||
|
text_projection: candle_nn::Linear,
|
||||||
|
logit_scale: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum EncoderConfig {
|
||||||
|
Text(text_model::ClipTextConfig),
|
||||||
|
Vision(vision_model::ClipVisionConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderConfig {
|
||||||
|
pub fn embed_dim(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
Self::Text(c) => c.embed_dim,
|
||||||
|
Self::Vision(c) => c.embed_dim,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn num_attention_heads(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
Self::Text(c) => c.num_attention_heads,
|
||||||
|
Self::Vision(c) => c.num_attention_heads,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn intermediate_size(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
Self::Text(c) => c.intermediate_size,
|
||||||
|
Self::Vision(c) => c.intermediate_size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn num_hidden_layers(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
Self::Text(c) => c.num_hidden_layers,
|
||||||
|
Self::Vision(c) => c.num_hidden_layers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn activation(&self) -> Activation {
|
||||||
|
match self {
|
||||||
|
Self::Text(_c) => Activation::QuickGelu,
|
||||||
|
Self::Vision(c) => c.activation,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ClipConfig {
|
||||||
|
pub text_config: text_model::ClipTextConfig,
|
||||||
|
pub vision_config: vision_model::ClipVisionConfig,
|
||||||
|
pub logit_scale_init_value: f32,
|
||||||
|
pub image_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipConfig {
|
||||||
|
// base image size is 224, model size is 600Mb
|
||||||
|
pub fn vit_base_patch32() -> Self {
|
||||||
|
let text_config = text_model::ClipTextConfig::vit_base_patch32();
|
||||||
|
let vision_config = vision_model::ClipVisionConfig::vit_base_patch32();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
text_config,
|
||||||
|
vision_config,
|
||||||
|
logit_scale_init_value: 2.6592,
|
||||||
|
image_size: 224,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipModel {
|
||||||
|
pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
|
||||||
|
let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
|
||||||
|
|
||||||
|
let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
|
||||||
|
|
||||||
|
let visual_projection = candle_nn::linear_no_bias(
|
||||||
|
c.vision_config.embed_dim,
|
||||||
|
c.vision_config.projection_dim,
|
||||||
|
vs.pp("visual_projection"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let text_projection = candle_nn::linear_no_bias(
|
||||||
|
c.text_config.embed_dim,
|
||||||
|
c.text_config.projection_dim,
|
||||||
|
vs.pp("text_projection"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// originally nn.Parameter
|
||||||
|
let logit_scale = if vs.contains_tensor("logit_scale") {
|
||||||
|
vs.get(&[], "logit_scale")?
|
||||||
|
} else {
|
||||||
|
warn!("Creating logit_scale tensor, results may vary.");
|
||||||
|
Tensor::new(&[c.logit_scale_init_value], vs.device())?
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
text_model,
|
||||||
|
vision_model,
|
||||||
|
visual_projection,
|
||||||
|
text_projection,
|
||||||
|
logit_scale,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let text_outputs = self.text_model.forward(input_ids)?;
|
||||||
|
|
||||||
|
let text_features = self.text_projection.forward(&text_outputs)?;
|
||||||
|
|
||||||
|
Ok(text_features)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||||
|
let image_features = self.vision_model.forward(pixel_values)?;
|
||||||
|
|
||||||
|
let image_features = self.visual_projection.forward(&image_features)?;
|
||||||
|
|
||||||
|
Ok(image_features)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||||
|
let image_features = self.get_image_features(pixel_values)?;
|
||||||
|
|
||||||
|
let text_features = self.get_text_features(input_ids)?;
|
||||||
|
|
||||||
|
let image_features_normalized = div_l2_norm(&image_features)?;
|
||||||
|
|
||||||
|
let text_features_normalized = div_l2_norm(&text_features)?;
|
||||||
|
|
||||||
|
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
|
||||||
|
|
||||||
|
let logit_scale = &self.logit_scale.exp()?;
|
||||||
|
|
||||||
|
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
|
||||||
|
|
||||||
|
let logits_per_image = logits_per_text.t()?;
|
||||||
|
|
||||||
|
Ok((logits_per_text, logits_per_image))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
|
||||||
|
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
|
||||||
|
v.broadcast_div(&l2_norm)
|
||||||
|
}
|
355
candle-transformers/src/models/clip/text_model.rs
Normal file
355
candle-transformers/src/models/clip/text_model.rs
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
//! Contrastive Language-Image Pre-Training
|
||||||
|
//!
|
||||||
|
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||||
|
//! pairs of images with related texts.
|
||||||
|
//!
|
||||||
|
//! https://github.com/openai/CLIP
|
||||||
|
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
|
||||||
|
|
||||||
|
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||||
|
use candle_nn as nn;
|
||||||
|
use candle_nn::Module;
|
||||||
|
|
||||||
|
use super::EncoderConfig;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum Activation {
|
||||||
|
QuickGelu,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Activation {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ClipTextConfig {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub embed_dim: usize,
|
||||||
|
pub activation: Activation,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub pad_with: Option<String>,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub projection_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipTextConfig {
|
||||||
|
// The config details can be found in the "text_config" section of this json file:
|
||||||
|
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
||||||
|
pub fn vit_base_patch32() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 49408,
|
||||||
|
embed_dim: 512,
|
||||||
|
intermediate_size: 2048,
|
||||||
|
max_position_embeddings: 77,
|
||||||
|
pad_with: None,
|
||||||
|
num_hidden_layers: 12,
|
||||||
|
num_attention_heads: 8,
|
||||||
|
projection_dim: 512,
|
||||||
|
activation: Activation::QuickGelu,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model.
|
||||||
|
// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ClipTextEmbeddings {
|
||||||
|
token_embedding: candle_nn::Embedding,
|
||||||
|
position_embedding: candle_nn::Embedding,
|
||||||
|
position_ids: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipTextEmbeddings {
|
||||||
|
fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
|
||||||
|
let token_embedding =
|
||||||
|
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
|
||||||
|
|
||||||
|
let position_embedding: nn::Embedding = candle_nn::embedding(
|
||||||
|
c.max_position_embeddings,
|
||||||
|
c.embed_dim,
|
||||||
|
vs.pp("position_embedding"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let position_ids =
|
||||||
|
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
|
||||||
|
|
||||||
|
Ok(ClipTextEmbeddings {
|
||||||
|
token_embedding,
|
||||||
|
position_embedding,
|
||||||
|
position_ids,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ClipTextEmbeddings {
|
||||||
|
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let seq_length = input_ids.dim(D::Minus1)?;
|
||||||
|
|
||||||
|
let inputs_embeds = &self.token_embedding.forward(input_ids)?;
|
||||||
|
|
||||||
|
let postion_ids = &self.position_ids.narrow(1, 0, seq_length)?;
|
||||||
|
|
||||||
|
let position_embedding = &self.position_embedding.forward(&postion_ids)?;
|
||||||
|
|
||||||
|
let inputs_embeds = inputs_embeds.broadcast_add(&position_embedding)?;
|
||||||
|
|
||||||
|
Ok(inputs_embeds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ClipAttention {
|
||||||
|
k_proj: candle_nn::Linear,
|
||||||
|
v_proj: candle_nn::Linear,
|
||||||
|
q_proj: candle_nn::Linear,
|
||||||
|
out_proj: candle_nn::Linear,
|
||||||
|
head_dim: usize,
|
||||||
|
scale: f64,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipAttention {
|
||||||
|
fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
|
||||||
|
let embed_dim = c.embed_dim();
|
||||||
|
let num_attention_heads = c.num_attention_heads();
|
||||||
|
let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
|
||||||
|
let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
|
||||||
|
let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
|
||||||
|
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
|
||||||
|
let head_dim = embed_dim / num_attention_heads;
|
||||||
|
let scale = (head_dim as f64).powf(-0.5);
|
||||||
|
|
||||||
|
Ok(ClipAttention {
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
q_proj,
|
||||||
|
out_proj,
|
||||||
|
head_dim,
|
||||||
|
scale,
|
||||||
|
num_attention_heads,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
|
||||||
|
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let in_dtype = xs.dtype();
|
||||||
|
let (bsz, seq_len, embed_dim) = xs.dims3()?;
|
||||||
|
|
||||||
|
let query_states = (self.q_proj.forward(xs)? * self.scale)?;
|
||||||
|
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
|
||||||
|
let query_states = self
|
||||||
|
.shape(&query_states, seq_len, bsz)?
|
||||||
|
.reshape(proj_shape)?
|
||||||
|
.to_dtype(DType::F32)?;
|
||||||
|
let key_states = self
|
||||||
|
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
|
||||||
|
.reshape(proj_shape)?
|
||||||
|
.to_dtype(DType::F32)?;
|
||||||
|
let value_states = self
|
||||||
|
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
|
||||||
|
.reshape(proj_shape)?
|
||||||
|
.to_dtype(DType::F32)?;
|
||||||
|
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
||||||
|
|
||||||
|
let src_len = key_states.dim(1)?;
|
||||||
|
|
||||||
|
let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
|
||||||
|
let attn_reshape =
|
||||||
|
attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?;
|
||||||
|
|
||||||
|
let attn_weights = attn_reshape.broadcast_add(causal_attention_mask)?;
|
||||||
|
|
||||||
|
let attn_weights =
|
||||||
|
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
|
||||||
|
|
||||||
|
attn_weights
|
||||||
|
} else {
|
||||||
|
attn_weights
|
||||||
|
};
|
||||||
|
|
||||||
|
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||||
|
|
||||||
|
let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
|
||||||
|
let attn_output = attn_output
|
||||||
|
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((bsz, seq_len, embed_dim))?;
|
||||||
|
self.out_proj.forward(&attn_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ClipMlp {
|
||||||
|
fc1: candle_nn::Linear,
|
||||||
|
fc2: candle_nn::Linear,
|
||||||
|
activation: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipMlp {
|
||||||
|
fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
|
||||||
|
let fc1 = candle_nn::linear(c.embed_dim(), c.intermediate_size(), vs.pp("fc1"))?;
|
||||||
|
let fc2 = candle_nn::linear(c.intermediate_size(), c.embed_dim(), vs.pp("fc2"))?;
|
||||||
|
|
||||||
|
Ok(ClipMlp {
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
activation: c.activation(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipMlp {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let xs = self.fc1.forward(xs)?;
|
||||||
|
self.fc2.forward(&self.activation.forward(&xs)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ClipEncoderLayer {
|
||||||
|
self_attn: ClipAttention,
|
||||||
|
layer_norm1: candle_nn::LayerNorm,
|
||||||
|
mlp: ClipMlp,
|
||||||
|
layer_norm2: candle_nn::LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipEncoderLayer {
|
||||||
|
fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
|
||||||
|
let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
|
||||||
|
let layer_norm1 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm1"))?;
|
||||||
|
let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
|
||||||
|
let layer_norm2 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm2"))?;
|
||||||
|
|
||||||
|
Ok(ClipEncoderLayer {
|
||||||
|
self_attn,
|
||||||
|
layer_norm1,
|
||||||
|
mlp,
|
||||||
|
layer_norm2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = self.layer_norm1.forward(xs)?;
|
||||||
|
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
|
||||||
|
let residual = &xs;
|
||||||
|
let xs = self.layer_norm2.forward(&xs)?;
|
||||||
|
let xs = self.mlp.forward(&xs)?;
|
||||||
|
xs + residual
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ClipEncoder {
|
||||||
|
layers: Vec<ClipEncoderLayer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipEncoder {
|
||||||
|
pub fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
|
||||||
|
let vs = vs.pp("layers");
|
||||||
|
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
|
||||||
|
for index in 0..c.num_hidden_layers() {
|
||||||
|
let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
Ok(ClipEncoder { layers })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
xs = layer.forward(&xs, causal_attention_mask)?;
|
||||||
|
}
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A CLIP transformer based model.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ClipTextTransformer {
|
||||||
|
embeddings: ClipTextEmbeddings,
|
||||||
|
encoder: ClipEncoder,
|
||||||
|
final_layer_norm: candle_nn::LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipTextTransformer {
|
||||||
|
pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
|
||||||
|
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
|
||||||
|
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
|
||||||
|
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
|
||||||
|
|
||||||
|
Ok(ClipTextTransformer {
|
||||||
|
embeddings,
|
||||||
|
encoder,
|
||||||
|
final_layer_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: rewrrite to newer version
|
||||||
|
fn build_causal_attention_mask(
|
||||||
|
bsz: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
mask_after: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = (0..seq_len)
|
||||||
|
.flat_map(|i| {
|
||||||
|
(0..seq_len).map(move |j| {
|
||||||
|
if j > i || j > mask_after {
|
||||||
|
f32::MIN
|
||||||
|
} else {
|
||||||
|
0.
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
||||||
|
mask.broadcast_as((bsz, 1, seq_len, seq_len))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
|
||||||
|
let (bsz, seq_len) = input_ids.dims2()?;
|
||||||
|
let input_ids = self.embeddings.forward(input_ids)?;
|
||||||
|
|
||||||
|
let causal_attention_mask =
|
||||||
|
Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
|
||||||
|
let input_ids = self
|
||||||
|
.encoder
|
||||||
|
.forward(&input_ids, Some(&causal_attention_mask))?;
|
||||||
|
self.final_layer_norm.forward(&input_ids)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ClipTextTransformer {
|
||||||
|
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let output = self.forward_with_mask(input_ids, usize::MAX)?;
|
||||||
|
|
||||||
|
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
|
||||||
|
|
||||||
|
let mut indices: Vec<Tensor> = Vec::new();
|
||||||
|
|
||||||
|
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
|
||||||
|
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
|
||||||
|
indices.push(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
let pooled_output = Tensor::cat(&indices, 0)?;
|
||||||
|
|
||||||
|
Ok(pooled_output)
|
||||||
|
}
|
||||||
|
}
|
171
candle-transformers/src/models/clip/vision_model.rs
Normal file
171
candle-transformers/src/models/clip/vision_model.rs
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
//! Contrastive Language-Image Pre-Training
|
||||||
|
//!
|
||||||
|
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||||
|
//! pairs of images with related texts.
|
||||||
|
//!
|
||||||
|
//! https://github.com/openai/CLIP
|
||||||
|
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
|
||||||
|
|
||||||
|
use candle::{IndexOp, Result, Shape, Tensor, D};
|
||||||
|
use candle_nn as nn;
|
||||||
|
use candle_nn::Module;
|
||||||
|
use nn::Conv2dConfig;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
text_model::{Activation, ClipEncoder},
|
||||||
|
EncoderConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ClipVisionConfig {
|
||||||
|
pub embed_dim: usize,
|
||||||
|
pub activation: Activation,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub projection_dim: usize,
|
||||||
|
pub num_channels: usize,
|
||||||
|
pub image_size: usize,
|
||||||
|
pub patch_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipVisionConfig {
|
||||||
|
// The config details can be found in the "vision_config" section of this json file:
|
||||||
|
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
||||||
|
pub fn vit_base_patch32() -> Self {
|
||||||
|
Self {
|
||||||
|
embed_dim: 768,
|
||||||
|
activation: Activation::QuickGelu,
|
||||||
|
intermediate_size: 3072,
|
||||||
|
num_hidden_layers: 12,
|
||||||
|
num_attention_heads: 12,
|
||||||
|
projection_dim: 512,
|
||||||
|
num_channels: 3,
|
||||||
|
image_size: 224,
|
||||||
|
patch_size: 32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ClipVisionEmbeddings {
|
||||||
|
patch_embedding: candle_nn::Conv2d,
|
||||||
|
position_ids: Tensor,
|
||||||
|
class_embedding: Tensor,
|
||||||
|
position_embedding: candle_nn::Embedding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipVisionEmbeddings {
|
||||||
|
fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
|
||||||
|
// originally nn.Parameter
|
||||||
|
let class_embedding = if vs.contains_tensor("class_embedding") {
|
||||||
|
vs.get(c.embed_dim, "class_embedding")?
|
||||||
|
} else {
|
||||||
|
warn!("class_embedding not found in the. Initializing a new one.");
|
||||||
|
Tensor::randn(0.0 as f32, 1.0 as f32, &[c.embed_dim], vs.device())?
|
||||||
|
};
|
||||||
|
|
||||||
|
let num_patches = (c.image_size / c.patch_size).pow(2);
|
||||||
|
|
||||||
|
let num_positions = num_patches + 1;
|
||||||
|
|
||||||
|
let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
|
||||||
|
|
||||||
|
let conv2dconfig = Conv2dConfig {
|
||||||
|
stride: c.patch_size,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let position_embedding =
|
||||||
|
candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?;
|
||||||
|
|
||||||
|
let patch_embedding = candle_nn::conv2d_no_bias(
|
||||||
|
c.num_channels,
|
||||||
|
c.embed_dim,
|
||||||
|
c.patch_size,
|
||||||
|
conv2dconfig,
|
||||||
|
vs.pp("patch_embedding"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
patch_embedding,
|
||||||
|
position_ids,
|
||||||
|
class_embedding,
|
||||||
|
position_embedding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ClipVisionEmbeddings {
|
||||||
|
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||||
|
let batch_size = pixel_values.shape().dims();
|
||||||
|
|
||||||
|
let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
|
||||||
|
|
||||||
|
let patch_embeds = patch_embeds.flatten_from(2)?;
|
||||||
|
|
||||||
|
let patch_embeds = patch_embeds.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let class_embedding = self.class_embedding.clone();
|
||||||
|
|
||||||
|
let shape = Shape::from(vec![batch_size[0], 1, class_embedding.dim(D::Minus1)?]);
|
||||||
|
|
||||||
|
let class_embeds = class_embedding.expand(shape)?;
|
||||||
|
|
||||||
|
let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
|
||||||
|
|
||||||
|
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
|
||||||
|
|
||||||
|
let embeddings = embeddings.broadcast_add(&position_embedding)?;
|
||||||
|
|
||||||
|
Ok(embeddings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ClipVisionTransformer {
|
||||||
|
embeddings: ClipVisionEmbeddings,
|
||||||
|
encoder: ClipEncoder,
|
||||||
|
pre_layer_norm: candle_nn::LayerNorm,
|
||||||
|
final_layer_norm: candle_nn::LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClipVisionTransformer {
|
||||||
|
pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
|
||||||
|
let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?;
|
||||||
|
|
||||||
|
let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
|
||||||
|
|
||||||
|
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?;
|
||||||
|
|
||||||
|
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embeddings,
|
||||||
|
encoder,
|
||||||
|
final_layer_norm,
|
||||||
|
pre_layer_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ClipVisionTransformer {
|
||||||
|
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||||
|
let hidden_states = self.embeddings.forward(pixel_values)?;
|
||||||
|
|
||||||
|
let hidden_states = self.pre_layer_norm.forward(&hidden_states)?;
|
||||||
|
|
||||||
|
let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
|
||||||
|
// pooled_output = encoder_outputs[:, 0, :]
|
||||||
|
let pooled_output = encoder_outputs.i((.., 0, ..))?;
|
||||||
|
|
||||||
|
let output = self.final_layer_norm.forward(&pooled_output)?;
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
}
|
@ -12,6 +12,7 @@ pub mod efficientvit;
|
|||||||
pub mod encodec;
|
pub mod encodec;
|
||||||
pub mod falcon;
|
pub mod falcon;
|
||||||
pub mod gemma;
|
pub mod gemma;
|
||||||
|
pub mod clip;
|
||||||
pub mod jina_bert;
|
pub mod jina_bert;
|
||||||
pub mod llama;
|
pub mod llama;
|
||||||
pub mod llama2_c;
|
pub mod llama2_c;
|
||||||
|
Reference in New Issue
Block a user