Add the SigLIP model. (#2515)

* Add the SigLIP model.

* Add more to the forward pass of the vision model.

* Complete the forward pass.

* Add the siglip example.

* Fix.

* Another fix.

* Get everything in place.

* Add a readme.
Laurent Mazare
2024-09-28 23:48:00 +02:00
committed by GitHub
parent 62525e8352
commit 261ed65f36
8 changed files with 797 additions and 54 deletions

View File

@@ -12,7 +12,6 @@ use candle_nn::{ops::softmax, VarBuilder};
 use candle_transformers::models::clip;
 use tokenizers::Tokenizer;
-use tracing::info;
 
 #[derive(Parser)]
 struct Args {
@@ -40,15 +39,12 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::
         height as u32,
         image::imageops::FilterType::Triangle,
     );
     let img = img.to_rgb8();
     let img = img.into_raw();
     let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
         .permute((2, 0, 1))?
         .to_dtype(DType::F32)?
         .affine(2. / 255., -1.)?;
-    // .unsqueeze(0)?;
     Ok(img)
 }
@@ -57,24 +53,16 @@ fn load_images<T: AsRef<std::path::Path>>(
     image_size: usize,
 ) -> anyhow::Result<Tensor> {
     let mut images = vec![];
     for path in paths {
         let tensor = load_image(path, image_size)?;
         images.push(tensor);
     }
     let images = Tensor::stack(&images, 0)?;
     Ok(images)
 }
 
 pub fn main() -> anyhow::Result<()> {
-    // std::env::set_var("RUST_BACKTRACE", "full");
     let args = Args::parse();
-    tracing_subscriber::fmt::init();
     let model_file = match args.model {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
@@ -89,13 +77,9 @@ pub fn main() -> anyhow::Result<()> {
         }
         Some(model) => model.into(),
     };
     let tokenizer = get_tokenizer(args.tokenizer)?;
     let config = clip::ClipConfig::vit_base_patch32();
     let device = candle_examples::device(args.cpu)?;
     let vec_imgs = match args.images {
         Some(imgs) => imgs,
         None => vec![
@@ -103,43 +87,29 @@ pub fn main() -> anyhow::Result<()> {
             "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
         ],
     };
-    // let image = load_image(args.image, config.image_size)?.to_device(&device)?;
     let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
     let vb =
         unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
     let model = clip::ClipModel::new(vb, &config)?;
     let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
     let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
     let softmax_image = softmax(&logits_per_image, 1)?;
     let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
-    info!("softmax_image_vec: {:?}", softmax_image_vec);
+    println!("softmax_image_vec: {:?}", softmax_image_vec);
     let probability_vec = softmax_image_vec
         .iter()
         .map(|v| v * 100.0)
         .collect::<Vec<f32>>();
     let probability_per_image = probability_vec.len() / vec_imgs.len();
     for (i, img) in vec_imgs.iter().enumerate() {
         let start = i * probability_per_image;
         let end = start + probability_per_image;
         let prob = &probability_vec[start..end];
-        info!("\n\nResults for image: {}\n", img);
+        println!("\n\nResults for image: {}\n", img);
         for (i, p) in prob.iter().enumerate() {
-            info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
+            println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
         }
     }
     Ok(())
 }
@@ -156,7 +126,6 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
         }
         Some(file) => file.into(),
     };
     Tokenizer::from_file(tokenizer).map_err(E::msg)
 }
@@ -169,7 +138,6 @@ pub fn tokenize_sequences(
         .get_vocab(true)
         .get("<|endoftext|>")
         .ok_or(E::msg("No pad token"))?;
     let vec_seq = match sequences {
         Some(seq) => seq,
         None => vec![
@@ -178,16 +146,12 @@ pub fn tokenize_sequences(
             "a robot holding a candle".to_string(),
         ],
     };
     let mut tokens = vec![];
     for seq in vec_seq.clone() {
         let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
         tokens.push(encoding.get_ids().to_vec());
     }
     let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
     // Pad the sequences to have the same length
     for token_vec in tokens.iter_mut() {
         let len_diff = max_len - token_vec.len();
@@ -195,8 +159,6 @@ pub fn tokenize_sequences(
             token_vec.extend(vec![pad_id; len_diff]);
         }
     }
     let input_ids = Tensor::new(tokens, device)?;
     Ok((input_ids, vec_seq))
 }

View File

@@ -0,0 +1,24 @@
## SigLIP
SigLIP is a multi-modal text-vision model that improves over CLIP by using a sigmoid-based loss; see the [model card on HuggingFace](https://huggingface.co/google/siglip-base-patch16-224).
### Running an example
```
$ cargo run --features cuda -r --example siglip -
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Probability: 0.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
Probability: 100.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 0.0000% Text: a robot holding a candle
```
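The softmax above is only for display: it normalizes each image's logits across the three prompts. Since SigLIP is trained with a pairwise sigmoid loss, each image/text pair can also be scored independently. A minimal sketch of that conversion, assuming a `logits_per_image` tensor of shape `(images, texts)` as returned by `Model::forward` (the helper name is illustrative, not part of this commit):

```rust
use candle::Tensor;
use candle_nn::ops::sigmoid;

/// Illustrative helper: turn SigLIP image/text logits into independent
/// per-pair probabilities in [0, 1], matching the sigmoid-based training loss.
fn pairwise_probabilities(logits_per_image: &Tensor) -> candle::Result<Vec<Vec<f32>>> {
    sigmoid(logits_per_image)?.to_vec2::<f32>()
}
```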

View File

@@ -0,0 +1,153 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::siglip;
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
Ok(img)
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = siglip::Config::base_patch16_224();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = siglip::Model::new(&config, vb)?;
let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
pub fn tokenize_sequences(
config: &siglip::Config,
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = config.text_config.pad_token_id;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = config.text_config.max_position_embeddings;
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}
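One caveat in `tokenize_sequences` above: every prompt is padded up to `max_position_embeddings` (64 for this checkpoint), so `max_len - token_vec.len()` underflows if a prompt tokenizes to more than 64 ids. A minimal guard, reusing the same variables and assuming truncation is acceptable, could clamp before padding:

```rust
// Hypothetical guard (not part of the commit): clamp each sequence to the
// model's maximum position embeddings before padding it out.
for token_vec in tokens.iter_mut() {
    token_vec.truncate(max_len);
    let len_diff = max_len - token_vec.len();
    if len_diff > 0 {
        token_vec.extend(vec![pad_id; len_diff]);
    }
}
```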

View File

@@ -92,28 +92,23 @@ impl ClipConfig {
 impl ClipModel {
     pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
         let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
         let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
         let visual_projection = candle_nn::linear_no_bias(
             c.vision_config.embed_dim,
             c.vision_config.projection_dim,
             vs.pp("visual_projection"),
         )?;
         let text_projection = candle_nn::linear_no_bias(
             c.text_config.embed_dim,
             c.text_config.projection_dim,
             vs.pp("text_projection"),
         )?;
         // originally nn.Parameter
         let logit_scale = if vs.contains_tensor("logit_scale") {
             vs.get(&[], "logit_scale")?
         } else {
             Tensor::new(&[c.logit_scale_init_value], vs.device())?
         };
         Ok(Self {
             text_model,
             vision_model,
View File

@@ -77,7 +77,7 @@ impl ClipTextEmbeddings {
         )?;
         let position_ids =
             Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
-        Ok(ClipTextEmbeddings {
+        Ok(Self {
             token_embedding,
             position_embedding,
             position_ids,
@@ -298,7 +298,7 @@ impl ClipTextTransformer {
         })
     }
-    // TODO: rewrrite to newer version
+    // TODO: rewrite to newer version
     fn build_causal_attention_mask(
         bsz: usize,
         seq_len: usize,

View File

@@ -11,13 +11,13 @@ use candle_nn::{
     BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
 };
 
-#[derive(Clone, Debug)]
+#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
 pub struct Config {
-    exp_ratio: usize,
-    in_channels: usize,
-    blocks: [usize; 4],
-    attn: bool,
-    lkc_use_act: bool,
+    pub exp_ratio: usize,
+    pub in_channels: usize,
+    pub blocks: [usize; 4],
+    pub attn: bool,
+    pub lkc_use_act: bool,
 }
 
 impl Config {

View File

@@ -76,6 +76,7 @@ pub mod rwkv_v5;
 pub mod rwkv_v6;
 pub mod segformer;
 pub mod segment_anything;
+pub mod siglip;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod starcoder2;

View File

@@ -0,0 +1,608 @@
use crate::models::clip::div_l2_norm;
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
#[derive(serde::Deserialize, Clone, Debug)]
pub struct TextConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub max_position_embeddings: usize,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
pub pad_token_id: u32,
pub bos_token_id: u32,
pub eos_token_id: u32,
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
#[derive(serde::Deserialize, Clone, Debug)]
pub struct VisionConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_channels: usize,
pub image_size: usize,
pub patch_size: usize,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
}
trait TransformerConfig {
fn hidden_size(&self) -> usize;
fn intermediate_size(&self) -> usize;
fn num_attention_heads(&self) -> usize;
fn num_hidden_layers(&self) -> usize;
fn layer_norm_eps(&self) -> f64;
fn hidden_act(&self) -> candle_nn::Activation;
}
impl TransformerConfig for TextConfig {
fn hidden_size(&self) -> usize {
self.hidden_size
}
fn intermediate_size(&self) -> usize {
self.intermediate_size
}
fn num_attention_heads(&self) -> usize {
self.num_attention_heads
}
fn num_hidden_layers(&self) -> usize {
self.num_hidden_layers
}
fn layer_norm_eps(&self) -> f64 {
self.layer_norm_eps
}
fn hidden_act(&self) -> candle_nn::Activation {
self.hidden_act
}
}
impl TransformerConfig for VisionConfig {
fn hidden_size(&self) -> usize {
self.hidden_size
}
fn intermediate_size(&self) -> usize {
self.intermediate_size
}
fn num_attention_heads(&self) -> usize {
self.num_attention_heads
}
fn num_hidden_layers(&self) -> usize {
self.num_hidden_layers
}
fn layer_norm_eps(&self) -> f64 {
self.layer_norm_eps
}
fn hidden_act(&self) -> candle_nn::Activation {
self.hidden_act
}
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L228
#[derive(serde::Deserialize, Clone, Debug)]
pub struct Config {
pub text_config: TextConfig,
pub vision_config: VisionConfig,
}
impl Config {
pub fn base_patch16_224() -> Self {
let text_config = TextConfig {
// https://huggingface.co/google/siglip-base-patch16-224/blob/main/config.json
hidden_size: 768,
intermediate_size: 3072,
num_attention_heads: 12,
vocab_size: 32000,
// Default values.
pad_token_id: 1,
bos_token_id: 49406,
eos_token_id: 49407,
layer_norm_eps: 1e-6,
hidden_act: candle_nn::Activation::GeluPytorchTanh,
max_position_embeddings: 64,
num_hidden_layers: 12,
};
let vision_config = VisionConfig {
patch_size: 16,
// Default values.
hidden_size: 768,
intermediate_size: 3072,
num_hidden_layers: 12,
num_attention_heads: 12,
num_channels: 3,
image_size: 224,
hidden_act: candle_nn::Activation::GeluPytorchTanh,
layer_norm_eps: 1e-6,
};
Self {
text_config,
vision_config,
}
}
}
#[derive(Clone, Debug)]
struct MultiheadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
num_heads: usize,
}
impl MultiheadAttention {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
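// The pooling head's checkpoint stores the q/k/v projections as a single packed
// tensor (torch.nn.MultiheadAttention layout), so split in_proj into three parts.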
let w_in_proj = vb.get((3 * h, h), "in_proj_weight")?.chunk(3, 0)?;
let b_in_proj = vb.get(3 * h, "in_proj_bias")?.chunk(3, 0)?;
let q_proj = Linear::new(w_in_proj[0].clone(), Some(b_in_proj[0].clone()));
let k_proj = Linear::new(w_in_proj[1].clone(), Some(b_in_proj[1].clone()));
let v_proj = Linear::new(w_in_proj[2].clone(), Some(b_in_proj[2].clone()));
let out_proj = linear(h, h, vb.pp("out_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
num_heads,
})
}
fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n, c) = x.dims3()?;
x.reshape((b, n, self.num_heads, c / self.num_heads))?
.transpose(1, 2)?
.contiguous()
}
fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
x.transpose(1, 2)?
.reshape((b, n_tokens, n_heads * c_per_head))
}
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let q = self.q_proj.forward(&q.contiguous()?)?;
let k = self.k_proj.forward(&k.contiguous()?)?;
let v = self.v_proj.forward(&v.contiguous()?)?;
let q = self.separate_heads(&q)?;
let k = self.separate_heads(&k)?;
let v = self.separate_heads(&v)?;
let (_, _, _, c_per_head) = q.dims4()?;
let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
let out = attn.matmul(&v)?;
self.recombine_heads(&out)?.apply(&self.out_proj)
}
}
#[derive(Debug, Clone)]
struct MultiheadAttentionPoolingHead {
probe: Tensor,
attention: MultiheadAttention,
layernorm: LayerNorm,
mlp: Mlp,
}
impl MultiheadAttentionPoolingHead {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
let probe = vb.get((1, 1, cfg.hidden_size), "probe")?;
let attention = MultiheadAttention::new(cfg, vb.pp("attention"))?;
Ok(Self {
probe,
attention,
layernorm,
mlp,
})
}
}
impl Module for MultiheadAttentionPoolingHead {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let batch_size = xs.dim(0)?;
let probe = self.probe.repeat((batch_size, 1, 1))?;
let xs = self.attention.forward(&probe, xs, xs)?;
let residual = &xs;
let xs = xs.apply(&self.layernorm)?.apply(&self.mlp)?;
(xs + residual)?.i((.., 0))
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
num_heads: usize,
head_dim: usize,
scale: f64,
}
impl Attention {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.hidden_size();
let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
let num_heads = cfg.num_attention_heads();
let head_dim = embed_dim / num_heads;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
num_heads,
head_dim,
scale: (head_dim as f64).powf(-0.5),
})
}
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let (batch_size, q_len, _) = xs.dims3()?;
let query_states = xs.apply(&self.q_proj)?;
let key_states = xs.apply(&self.k_proj)?;
let value_states = xs.apply(&self.v_proj)?;
let shape = (batch_size, q_len, self.num_heads, self.head_dim);
let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
// The original implementation upcasts to f32 but candle_nn::ops::softmax should handle this properly.
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_outputs = attn_weights
.matmul(&value_states)?
.transpose(1, 2)?
.reshape((batch_size, q_len, ()))?
.apply(&self.out_proj)?;
Ok(attn_outputs)
}
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L599
#[derive(Debug, Clone)]
struct Mlp {
fc1: Linear,
fc2: Linear,
activation_fn: candle_nn::Activation,
}
impl Mlp {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size();
let intermediate_size = cfg.intermediate_size();
let fc1 = candle_nn::linear(hidden_size, intermediate_size, vb.pp("fc1"))?;
let fc2 = candle_nn::linear(intermediate_size, hidden_size, vb.pp("fc2"))?;
Ok(Self {
fc1,
fc2,
activation_fn: cfg.hidden_act(),
})
}
}
impl Module for Mlp {
fn forward(&self, xs: &candle::Tensor) -> Result<candle::Tensor> {
xs.apply(&self.fc1)?
.apply(&self.activation_fn)?
.apply(&self.fc2)
}
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L614
#[derive(Debug, Clone)]
struct EncoderLayer {
self_attn: Attention,
layer_norm1: LayerNorm,
mlp: Mlp,
layer_norm2: LayerNorm,
}
impl EncoderLayer {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size();
let layer_norm_eps = cfg.layer_norm_eps();
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
let layer_norm1 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm1"))?;
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
let layer_norm2 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm2"))?;
Ok(Self {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.layer_norm1)?;
let xs = self.self_attn.forward(&xs, attention_mask)?;
let xs = (residual + xs)?;
let residual = &xs;
let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
let xs = (xs + residual)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct Encoder {
layers: Vec<EncoderLayer>,
}
impl Encoder {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let mut layers = vec![];
let vb = vb.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers() {
let layer = EncoderLayer::new(cfg, vb.pp(layer_idx))?;
layers.push(layer)
}
Ok(Self { layers })
}
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, attention_mask)?
}
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct VisionEmbeddings {
patch_embedding: candle_nn::Conv2d,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
}
impl VisionEmbeddings {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let conv2d_cfg = candle_nn::Conv2dConfig {
stride: cfg.patch_size,
..Default::default()
};
let patch_embedding = candle_nn::conv2d(
cfg.num_channels,
cfg.hidden_size,
cfg.patch_size,
conv2d_cfg,
vb.pp("patch_embedding"),
)?;
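// For the base checkpoint this is (224 / 16)^2 = 196 patches, one position id per patch.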
let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?;
let position_embedding =
candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?;
Ok(Self {
patch_embedding,
position_embedding,
position_ids,
})
}
}
impl Module for VisionEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_batch, _channels, _height, _width) = xs.dims4()?;
let embeddings = xs.apply(&self.patch_embedding)?;
let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
embeddings.broadcast_add(&position_embedding)
}
}
#[derive(Debug, Clone)]
struct VisionTransformer {
embeddings: VisionEmbeddings,
encoder: Encoder,
post_layernorm: LayerNorm,
head: Option<MultiheadAttentionPoolingHead>,
}
impl VisionTransformer {
fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
let post_layernorm =
layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
let head = if use_head {
Some(MultiheadAttentionPoolingHead::new(cfg, vb.pp("head"))?)
} else {
None
};
Ok(Self {
embeddings,
encoder,
post_layernorm,
head,
})
}
}
impl Module for VisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embeddings)?;
let xs = self.encoder.forward(&xs, None)?;
let xs = xs.apply(&self.post_layernorm)?;
match self.head.as_ref() {
None => Ok(xs),
Some(h) => xs.apply(h),
}
}
}
#[derive(Debug, Clone)]
pub struct VisionModel {
vision_model: VisionTransformer,
}
impl VisionModel {
pub fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
let vision_model = VisionTransformer::new(cfg, use_head, vb)?;
Ok(Self { vision_model })
}
}
impl Module for VisionModel {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.vision_model)
}
}
#[derive(Debug, Clone)]
struct TextEmbeddings {
token_embedding: candle_nn::Embedding,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
}
impl TextEmbeddings {
fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
let token_embedding =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
let position_embedding = candle_nn::embedding(
cfg.max_position_embeddings,
cfg.hidden_size,
vb.pp("position_embedding"),
)?;
let position_ids =
Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
Ok(Self {
token_embedding,
position_embedding,
position_ids,
})
}
}
impl Module for TextEmbeddings {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let seq_length = input_ids.dim(D::Minus1)?;
let inputs_embeds = self.token_embedding.forward(input_ids)?;
let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
let position_embedding = self.position_embedding.forward(&position_ids)?;
inputs_embeds.broadcast_add(&position_embedding)
}
}
#[derive(Debug, Clone)]
pub struct TextTransformer {
embeddings: TextEmbeddings,
encoder: Encoder,
final_layer_norm: LayerNorm,
pub head: Linear,
}
impl TextTransformer {
fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
let final_layer_norm = layer_norm(
cfg.hidden_size,
cfg.layer_norm_eps,
vb.pp("final_layer_norm"),
)?;
let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
Ok(Self {
embeddings,
encoder,
final_layer_norm,
head,
})
}
}
impl Module for TextTransformer {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let (_bsz, seq_len) = input_ids.dims2()?;
let input_ids = self.embeddings.forward(input_ids)?;
let input_ids = self.encoder.forward(&input_ids, None)?;
let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
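// Pool by taking the last token's hidden state and projecting it through `head`.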
last_hidden_state
.i((.., seq_len - 1, ..))?
.contiguous()?
.apply(&self.head)
}
}
#[derive(Debug, Clone)]
pub struct TextModel {
pub text_model: TextTransformer,
}
impl TextModel {
pub fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
let text_model = TextTransformer::new(cfg, vb)?;
Ok(Self { text_model })
}
}
impl Module for TextModel {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.text_model)
}
}
#[derive(Clone, Debug)]
pub struct Model {
text_model: TextModel,
vision_model: VisionModel,
logit_bias: Tensor,
logit_scale: Tensor,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let text_model = TextModel::new(&cfg.text_config, vb.pp("text_model"))?;
let vision_model = VisionModel::new(&cfg.vision_config, true, vb.pp("vision_model"))?;
let logit_scale = vb.get(&[1], "logit_scale")?;
let logit_bias = vb.get(&[1], "logit_bias")?;
Ok(Self {
text_model,
vision_model,
logit_bias,
logit_scale,
})
}
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
input_ids.apply(&self.text_model)
}
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
pixel_values.apply(&self.vision_model)
}
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
let image_features = self.get_image_features(pixel_values)?;
let text_features = self.get_text_features(input_ids)?;
let image_features_normalized = div_l2_norm(&image_features)?;
let text_features_normalized = div_l2_norm(&text_features)?;
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
let logit_scale = self.logit_scale.exp()?;
let logits_per_text = logits_per_text
.broadcast_mul(&logit_scale)?
.broadcast_add(&self.logit_bias)?;
let logits_per_image = logits_per_text.t()?;
Ok((logits_per_text, logits_per_image))
}
}
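Besides the paired `forward`, the model exposes `get_text_features` and `get_image_features`, which return the pooled, un-normalized embeddings. A small sketch of using them for embedding-style retrieval, normalizing the same way `forward` does (the function and variable names are illustrative, not part of the commit):

```rust
use candle::{Tensor, D};
use candle_transformers::models::siglip;

// Cosine similarities between every image and every text, assuming `model`,
// `pixel_values` and `input_ids` are built as in the siglip example above.
fn cosine_similarities(
    model: &siglip::Model,
    pixel_values: &Tensor,
    input_ids: &Tensor,
) -> candle::Result<Tensor> {
    // L2-normalize the pooled embeddings, mirroring what `Model::forward` does.
    let l2 = |v: &Tensor| -> candle::Result<Tensor> {
        v.broadcast_div(&v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)
    };
    let img = l2(&model.get_image_features(pixel_values)?)?;
    let txt = l2(&model.get_text_features(input_ids)?)?;
    // (n_images, dim) x (dim, n_texts) -> (n_images, n_texts)
    img.matmul(&txt.t()?)
}
```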