mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
MobileCLIP models S1 and S2 (#2454)
* Allow loading images with given std and mean * OpenCLIP text encoder component * Two MobileCLIP models * Clippy fixes. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
28
candle-examples/examples/mobileclip/README.md
Normal file
28
candle-examples/examples/mobileclip/README.md
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# candle-mobileclip
|
||||||
|
|
||||||
|
MobileCLIP is family of efficient CLIP-like models using FastViT-based image encoders.
|
||||||
|
|
||||||
|
See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049)
|
||||||
|
|
||||||
|
|
||||||
|
## Running on an example on cpu
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cargo run --example mobileclip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
|
||||||
|
|
||||||
|
softmax_image_vec: [2.4819004e-5, 3.81081e-6, 0.9999714, 0.9999738, 2.382714e-5, 2.3317718e-6]
|
||||||
|
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||||
|
|
||||||
|
Probability: 0.0025% Text: a cycling race
|
||||||
|
Probability: 0.0004% Text: a photo of two cats
|
||||||
|
Probability: 99.9971% Text: a robot holding a candle
|
||||||
|
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
Probability: 99.9974% Text: a cycling race
|
||||||
|
Probability: 0.0024% Text: a photo of two cats
|
||||||
|
Probability: 0.0002% Text: a robot holding a candle
|
||||||
|
```
|
192
candle-examples/examples/mobileclip/main.rs
Normal file
192
candle-examples/examples/mobileclip/main.rs
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::Error as E;
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_nn::{ops::softmax, VarBuilder};
|
||||||
|
use candle_transformers::models::mobileclip;
|
||||||
|
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
S1,
|
||||||
|
S2,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Which {
|
||||||
|
fn model_name(&self) -> String {
|
||||||
|
let name = match self {
|
||||||
|
Self::S1 => "S1",
|
||||||
|
Self::S2 => "S2",
|
||||||
|
};
|
||||||
|
format!("apple/MobileCLIP-{}-OpenCLIP", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config(&self) -> mobileclip::MobileClipConfig {
|
||||||
|
match self {
|
||||||
|
Self::S1 => mobileclip::MobileClipConfig::s1(),
|
||||||
|
Self::S2 => mobileclip::MobileClipConfig::s2(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long, use_value_delimiter = true)]
|
||||||
|
images: Option<Vec<String>>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Use the pytorch weights rather than the safetensors ones
|
||||||
|
#[arg(long)]
|
||||||
|
use_pth: bool,
|
||||||
|
|
||||||
|
#[arg(long, use_value_delimiter = true)]
|
||||||
|
sequences: Option<Vec<String>>,
|
||||||
|
|
||||||
|
#[arg(value_enum, long, default_value_t=Which::S1)]
|
||||||
|
which: Which,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_images<T: AsRef<std::path::Path>>(
|
||||||
|
paths: &Vec<T>,
|
||||||
|
image_size: usize,
|
||||||
|
) -> anyhow::Result<Tensor> {
|
||||||
|
let mut images = vec![];
|
||||||
|
|
||||||
|
for path in paths {
|
||||||
|
let tensor = candle_examples::imagenet::load_image_with_std_mean(
|
||||||
|
path,
|
||||||
|
image_size,
|
||||||
|
&[0.0, 0.0, 0.0],
|
||||||
|
&[1.0, 1.0, 1.0],
|
||||||
|
)?;
|
||||||
|
images.push(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
let images = Tensor::stack(&images, 0)?;
|
||||||
|
|
||||||
|
Ok(images)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let model_name = args.which.model_name();
|
||||||
|
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model(model_name);
|
||||||
|
|
||||||
|
let model_file = if args.use_pth {
|
||||||
|
api.get("open_clip_pytorch_model.bin")?
|
||||||
|
} else {
|
||||||
|
api.get("open_clip_model.safetensors")?
|
||||||
|
};
|
||||||
|
|
||||||
|
let tokenizer = api.get("tokenizer.json")?;
|
||||||
|
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let config = &args.which.config();
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let vec_imgs = match args.images {
|
||||||
|
Some(imgs) => imgs,
|
||||||
|
None => vec![
|
||||||
|
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
|
||||||
|
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
|
||||||
|
|
||||||
|
let vb = if args.use_pth {
|
||||||
|
VarBuilder::from_pth(&model_file, DType::F32, &device)?
|
||||||
|
} else {
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }
|
||||||
|
};
|
||||||
|
|
||||||
|
let model = mobileclip::MobileClipModel::new(vb, config)?;
|
||||||
|
|
||||||
|
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
|
||||||
|
|
||||||
|
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
|
||||||
|
|
||||||
|
let softmax_image = softmax(&logits_per_image, 1)?;
|
||||||
|
|
||||||
|
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
|
||||||
|
|
||||||
|
println!("softmax_image_vec: {:?}", softmax_image_vec);
|
||||||
|
|
||||||
|
let probability_vec = softmax_image_vec
|
||||||
|
.iter()
|
||||||
|
.map(|v| v * 100.0)
|
||||||
|
.collect::<Vec<f32>>();
|
||||||
|
|
||||||
|
let probability_per_image = probability_vec.len() / vec_imgs.len();
|
||||||
|
|
||||||
|
for (i, img) in vec_imgs.iter().enumerate() {
|
||||||
|
let start = i * probability_per_image;
|
||||||
|
let end = start + probability_per_image;
|
||||||
|
let prob = &probability_vec[start..end];
|
||||||
|
println!("\n\nResults for image: {}\n", img);
|
||||||
|
|
||||||
|
for (i, p) in prob.iter().enumerate() {
|
||||||
|
println!("Probability: {:.4}% Text: {}", p, vec_seq[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tokenize_sequences(
|
||||||
|
sequences: Option<Vec<String>>,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
device: &Device,
|
||||||
|
) -> anyhow::Result<(Tensor, Vec<String>)> {
|
||||||
|
// let pad_id = *tokenizer
|
||||||
|
// .get_vocab(true)
|
||||||
|
// .get("<|endoftext|>")
|
||||||
|
// .ok_or(E::msg("No pad token"))?;
|
||||||
|
|
||||||
|
// The model does not work well if the text is padded using the <|endoftext|> token, using 0
|
||||||
|
// as the original OpenCLIP code.
|
||||||
|
let pad_id = 0;
|
||||||
|
|
||||||
|
let vec_seq = match sequences {
|
||||||
|
Some(seq) => seq,
|
||||||
|
None => vec![
|
||||||
|
"a cycling race".to_string(),
|
||||||
|
"a photo of two cats".to_string(),
|
||||||
|
"a robot holding a candle".to_string(),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut tokens = vec![];
|
||||||
|
|
||||||
|
for seq in vec_seq.clone() {
|
||||||
|
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
|
||||||
|
tokens.push(encoding.get_ids().to_vec());
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
|
||||||
|
// Pad the sequences to have the same length
|
||||||
|
for token_vec in tokens.iter_mut() {
|
||||||
|
let len_diff = max_len - token_vec.len();
|
||||||
|
if len_diff > 0 {
|
||||||
|
token_vec.extend(vec![pad_id; len_diff]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let input_ids = Tensor::new(tokens, device)?;
|
||||||
|
|
||||||
|
Ok((input_ids, vec_seq))
|
||||||
|
}
|
@ -72,8 +72,9 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())?
|
let image =
|
||||||
.to_device(&device)?;
|
candle_examples::imagenet::load_image(args.image, args.which.resolution() as usize)?
|
||||||
|
.to_device(&device)?;
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let model_file = match args.model {
|
let model_file = match args.model {
|
||||||
|
@ -1,23 +1,42 @@
|
|||||||
use candle::{Device, Result, Tensor};
|
use candle::{Device, Result, Tensor};
|
||||||
|
|
||||||
/// Loads an image from disk using the image crate at the requested resolution.
|
pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
|
||||||
// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
|
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
|
||||||
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
|
|
||||||
|
/// Loads an image from disk using the image crate at the requested resolution,
|
||||||
|
/// using the given std and mean parameters.
|
||||||
|
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
|
||||||
|
|
||||||
|
pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
|
||||||
|
p: P,
|
||||||
|
res: usize,
|
||||||
|
mean: &[f32; 3],
|
||||||
|
std: &[f32; 3],
|
||||||
|
) -> Result<Tensor> {
|
||||||
let img = image::ImageReader::open(p)?
|
let img = image::ImageReader::open(p)?
|
||||||
.decode()
|
.decode()
|
||||||
.map_err(candle::Error::wrap)?
|
.map_err(candle::Error::wrap)?
|
||||||
.resize_to_fill(res, res, image::imageops::FilterType::Triangle);
|
.resize_to_fill(
|
||||||
|
res as u32,
|
||||||
|
res as u32,
|
||||||
|
image::imageops::FilterType::Triangle,
|
||||||
|
);
|
||||||
let img = img.to_rgb8();
|
let img = img.to_rgb8();
|
||||||
let data = img.into_raw();
|
let data = img.into_raw();
|
||||||
let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)?
|
let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||||
.permute((2, 0, 1))?;
|
let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||||
let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
|
let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||||
let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
|
|
||||||
(data.to_dtype(candle::DType::F32)? / 255.)?
|
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||||
.broadcast_sub(&mean)?
|
.broadcast_sub(&mean)?
|
||||||
.broadcast_div(&std)
|
.broadcast_div(&std)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Loads an image from disk using the image crate at the requested resolution.
|
||||||
|
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
|
||||||
|
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: usize) -> Result<Tensor> {
|
||||||
|
load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD)
|
||||||
|
}
|
||||||
|
|
||||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||||
/// (3, 224, 224). imagenet normalization is applied.
|
/// (3, 224, 224). imagenet normalization is applied.
|
||||||
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||||
|
89
candle-transformers/src/models/mobileclip.rs
Normal file
89
candle-transformers/src/models/mobileclip.rs
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
use super::fastvit;
|
||||||
|
use super::openclip::text_model;
|
||||||
|
use candle::{Result, Tensor, D};
|
||||||
|
use candle_nn::{Func, VarBuilder};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct MobileClipModel {
|
||||||
|
text_model: text_model::OpenClipTextTransformer,
|
||||||
|
vision_model: Func<'static>,
|
||||||
|
text_projection: Tensor,
|
||||||
|
logit_scale: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct MobileClipConfig {
|
||||||
|
pub text_config: text_model::Config,
|
||||||
|
pub vision_config: fastvit::Config,
|
||||||
|
pub image_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MobileClipConfig {
|
||||||
|
pub fn s1() -> Self {
|
||||||
|
let text_config = text_model::Config::vit_base_patch32();
|
||||||
|
let vision_config = fastvit::Config::mci1();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
text_config,
|
||||||
|
vision_config,
|
||||||
|
image_size: 256,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn s2() -> Self {
|
||||||
|
let text_config = text_model::Config::vit_base_patch32();
|
||||||
|
let vision_config = fastvit::Config::mci2();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
text_config,
|
||||||
|
vision_config,
|
||||||
|
image_size: 256,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MobileClipModel {
|
||||||
|
pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
|
||||||
|
let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
|
||||||
|
let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
|
||||||
|
|
||||||
|
let text_projection = vs.get(
|
||||||
|
(c.text_config.embed_dim, c.text_config.projection_dim),
|
||||||
|
"text.text_projection",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let logit_scale = vs.get(&[], "logit_scale")?;
|
||||||
|
Ok(Self {
|
||||||
|
text_model,
|
||||||
|
vision_model,
|
||||||
|
text_projection,
|
||||||
|
logit_scale,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
input_ids
|
||||||
|
.apply(&self.text_model)?
|
||||||
|
.matmul(&self.text_projection)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||||
|
pixel_values.apply(&self.vision_model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||||
|
let image_features = self.get_image_features(pixel_values)?;
|
||||||
|
let text_features = self.get_text_features(input_ids)?;
|
||||||
|
let image_features_normalized = div_l2_norm(&image_features)?;
|
||||||
|
let text_features_normalized = div_l2_norm(&text_features)?;
|
||||||
|
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
|
||||||
|
let logit_scale = self.logit_scale.exp()?;
|
||||||
|
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
|
||||||
|
let logits_per_image = logits_per_text.t()?;
|
||||||
|
Ok((logits_per_text, logits_per_image))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
|
||||||
|
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
|
||||||
|
v.broadcast_div(&l2_norm)
|
||||||
|
}
|
@ -37,11 +37,13 @@ pub mod mistral;
|
|||||||
pub mod mixformer;
|
pub mod mixformer;
|
||||||
pub mod mixtral;
|
pub mod mixtral;
|
||||||
pub mod mmdit;
|
pub mod mmdit;
|
||||||
|
pub mod mobileclip;
|
||||||
pub mod mobilenetv4;
|
pub mod mobilenetv4;
|
||||||
pub mod mobileone;
|
pub mod mobileone;
|
||||||
pub mod moondream;
|
pub mod moondream;
|
||||||
pub mod mpt;
|
pub mod mpt;
|
||||||
pub mod olmo;
|
pub mod olmo;
|
||||||
|
pub mod openclip;
|
||||||
pub mod parler_tts;
|
pub mod parler_tts;
|
||||||
pub mod persimmon;
|
pub mod persimmon;
|
||||||
pub mod phi;
|
pub mod phi;
|
||||||
|
1
candle-transformers/src/models/openclip/mod.rs
Normal file
1
candle-transformers/src/models/openclip/mod.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
pub mod text_model;
|
266
candle-transformers/src/models/openclip/text_model.rs
Normal file
266
candle-transformers/src/models/openclip/text_model.rs
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
//! Text encoder as used in most OpenCLIP pretrained models
|
||||||
|
//! https://github.com/mlfoundations/open_clip
|
||||||
|
|
||||||
|
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||||
|
use candle_nn::{
|
||||||
|
embedding, layer_norm, linear, ops::softmax_last_dim, Embedding, LayerNorm, Linear, Module,
|
||||||
|
VarBuilder,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub embed_dim: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub pad_with: Option<String>,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub projection_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn vit_base_patch32() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 49408,
|
||||||
|
embed_dim: 512,
|
||||||
|
intermediate_size: 2048,
|
||||||
|
max_position_embeddings: 77,
|
||||||
|
pad_with: None,
|
||||||
|
num_hidden_layers: 12,
|
||||||
|
num_attention_heads: 8,
|
||||||
|
projection_dim: 512,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct TextEmbeddings {
|
||||||
|
token_embedding: Embedding,
|
||||||
|
position_embedding: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextEmbeddings {
|
||||||
|
fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||||
|
let token_embedding = embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
|
||||||
|
let position_embedding = vs.get(
|
||||||
|
(c.max_position_embeddings, c.embed_dim),
|
||||||
|
"positional_embedding",
|
||||||
|
)?;
|
||||||
|
Ok(TextEmbeddings {
|
||||||
|
token_embedding,
|
||||||
|
position_embedding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TextEmbeddings {
|
||||||
|
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let seq_length = input_ids.dim(D::Minus1)?;
|
||||||
|
let inputs_embeds = self.token_embedding.forward(input_ids)?;
|
||||||
|
|
||||||
|
let position_embedding = self.position_embedding.narrow(0, 0, seq_length)?;
|
||||||
|
|
||||||
|
inputs_embeds.broadcast_add(&position_embedding)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct Attention {
|
||||||
|
k_proj: candle_nn::Linear,
|
||||||
|
v_proj: candle_nn::Linear,
|
||||||
|
q_proj: candle_nn::Linear,
|
||||||
|
out_proj: Linear,
|
||||||
|
head_dim: usize,
|
||||||
|
scale: f64,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||||
|
let embed_dim = c.embed_dim;
|
||||||
|
let num_attention_heads = c.num_attention_heads;
|
||||||
|
|
||||||
|
let in_proj_weights = vs
|
||||||
|
.get((embed_dim * 3, embed_dim), "in_proj_weight")?
|
||||||
|
.chunk(3, 0)?;
|
||||||
|
let (q_w, k_w, v_w) = (
|
||||||
|
&in_proj_weights[0],
|
||||||
|
&in_proj_weights[1],
|
||||||
|
&in_proj_weights[2],
|
||||||
|
);
|
||||||
|
let in_proj_biases = vs.get(embed_dim * 3, "in_proj_bias")?.chunk(3, 0)?;
|
||||||
|
let (q_b, k_b, v_b) = (&in_proj_biases[0], &in_proj_biases[1], &in_proj_biases[2]);
|
||||||
|
|
||||||
|
let q_proj = Linear::new(q_w.clone(), Some(q_b.clone()));
|
||||||
|
let k_proj = Linear::new(k_w.clone(), Some(k_b.clone()));
|
||||||
|
let v_proj = Linear::new(v_w.clone(), Some(v_b.clone()));
|
||||||
|
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
|
||||||
|
let head_dim = embed_dim / num_attention_heads;
|
||||||
|
let scale = (head_dim as f64).powf(-0.5);
|
||||||
|
|
||||||
|
Ok(Attention {
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
q_proj,
|
||||||
|
out_proj,
|
||||||
|
head_dim,
|
||||||
|
scale,
|
||||||
|
num_attention_heads,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shape_multihead(&self, xs: &Tensor, bsz: usize, seq_len: usize) -> Result<Tensor> {
|
||||||
|
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?
|
||||||
|
.to_dtype(DType::F32)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let in_dtype = xs.dtype();
|
||||||
|
let (bsz, seq_len, embed_dim) = xs.dims3()?;
|
||||||
|
|
||||||
|
let q = self.shape_multihead(&self.q_proj.forward(xs)?, bsz, seq_len)?;
|
||||||
|
let k = self.shape_multihead(&self.k_proj.forward(xs)?, bsz, seq_len)?;
|
||||||
|
let v = self.shape_multihead(&self.v_proj.forward(xs)?, bsz, seq_len)?;
|
||||||
|
let q = (q * self.scale)?;
|
||||||
|
|
||||||
|
let attn_weights = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
|
||||||
|
|
||||||
|
let attn_weights = softmax_last_dim(&attn_weights)?;
|
||||||
|
|
||||||
|
let attn_output = attn_weights.matmul(&v)?.to_dtype(in_dtype)?;
|
||||||
|
let attn_output = attn_output
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?
|
||||||
|
.reshape((bsz, seq_len, embed_dim))?;
|
||||||
|
let out = self.out_proj.forward(&attn_output)?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct Mlp {
|
||||||
|
fc1: Linear,
|
||||||
|
fc2: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mlp {
|
||||||
|
fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||||
|
let fc1 = linear(c.embed_dim, c.intermediate_size, vs.pp("c_fc"))?;
|
||||||
|
let fc2 = linear(c.intermediate_size, c.embed_dim, vs.pp("c_proj"))?;
|
||||||
|
|
||||||
|
Ok(Mlp { fc1, fc2 })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mlp {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let xs = self.fc1.forward(xs)?;
|
||||||
|
self.fc2.forward(&xs.gelu_erf()?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct EncoderLayer {
|
||||||
|
self_attn: Attention,
|
||||||
|
layer_norm1: LayerNorm,
|
||||||
|
mlp: Mlp,
|
||||||
|
layer_norm2: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderLayer {
|
||||||
|
fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||||
|
let self_attn = Attention::new(vs.pp("attn"), c)?;
|
||||||
|
let layer_norm1 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_1"))?;
|
||||||
|
let mlp = Mlp::new(vs.pp("mlp"), c)?;
|
||||||
|
let layer_norm2 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_2"))?;
|
||||||
|
|
||||||
|
Ok(EncoderLayer {
|
||||||
|
self_attn,
|
||||||
|
layer_norm1,
|
||||||
|
mlp,
|
||||||
|
layer_norm2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = self.layer_norm1.forward(xs)?;
|
||||||
|
let xs = self.self_attn.forward(&xs)?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
|
||||||
|
let residual = &xs;
|
||||||
|
let xs = self.layer_norm2.forward(&xs)?;
|
||||||
|
let xs = self.mlp.forward(&xs)?;
|
||||||
|
let out = (xs + residual)?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct Encoder {
|
||||||
|
layers: Vec<EncoderLayer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encoder {
|
||||||
|
pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||||
|
let vs = vs.pp("resblocks");
|
||||||
|
let mut layers: Vec<EncoderLayer> = Vec::new();
|
||||||
|
for index in 0..c.num_hidden_layers {
|
||||||
|
let layer = EncoderLayer::new(vs.pp(index.to_string()), c)?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
Ok(Encoder { layers })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
xs = layer.forward(&xs)?;
|
||||||
|
}
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A text transformer as used in CLIP variants.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct OpenClipTextTransformer {
|
||||||
|
embeddings: TextEmbeddings,
|
||||||
|
encoder: Encoder,
|
||||||
|
final_layer_norm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenClipTextTransformer {
|
||||||
|
pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||||
|
let embeddings = TextEmbeddings::new(vs.clone(), c)?;
|
||||||
|
let final_layer_norm = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_final"))?;
|
||||||
|
let encoder = Encoder::new(vs.pp("transformer"), c)?;
|
||||||
|
Ok(OpenClipTextTransformer {
|
||||||
|
embeddings,
|
||||||
|
encoder,
|
||||||
|
final_layer_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let input_ids = self.embeddings.forward(input_ids)?;
|
||||||
|
let input_ids = self.encoder.forward(&input_ids)?;
|
||||||
|
self.final_layer_norm.forward(&input_ids)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for OpenClipTextTransformer {
|
||||||
|
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let output = self.forward(input_ids)?;
|
||||||
|
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
|
||||||
|
|
||||||
|
let mut indices = Vec::new();
|
||||||
|
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
|
||||||
|
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
|
||||||
|
indices.push(index);
|
||||||
|
}
|
||||||
|
Tensor::cat(&indices, 0)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user