Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 18:28:24 +00:00.
Add the SigLIP model. (#2515)
* Add the SigLIP model.
* Add more to the forward pass of the vision model.
* Complete the forward pass.
* Add the siglip example.
* Fix.
* Another fix.
* Get everything in place.
* Add a readme.
candle-examples/examples/clip/main.rs

@@ -12,7 +12,6 @@ use candle_nn::{ops::softmax, VarBuilder};
 use candle_transformers::models::clip;
 
 use tokenizers::Tokenizer;
-use tracing::info;
 
 #[derive(Parser)]
 struct Args {
@@ -40,15 +39,12 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::
         height as u32,
         image::imageops::FilterType::Triangle,
     );
-
     let img = img.to_rgb8();
-
     let img = img.into_raw();
     let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
        .permute((2, 0, 1))?
        .to_dtype(DType::F32)?
        .affine(2. / 255., -1.)?;
-    // .unsqueeze(0)?;
     Ok(img)
 }
 
@@ -57,24 +53,16 @@ fn load_images<T: AsRef<std::path::Path>>(
     image_size: usize,
 ) -> anyhow::Result<Tensor> {
     let mut images = vec![];
-
     for path in paths {
         let tensor = load_image(path, image_size)?;
         images.push(tensor);
     }
-
     let images = Tensor::stack(&images, 0)?;
-
     Ok(images)
 }
 
 pub fn main() -> anyhow::Result<()> {
-    // std::env::set_var("RUST_BACKTRACE", "full");
-
     let args = Args::parse();
-
-    tracing_subscriber::fmt::init();
-
     let model_file = match args.model {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
@@ -89,13 +77,9 @@ pub fn main() -> anyhow::Result<()> {
         }
         Some(model) => model.into(),
     };
-
     let tokenizer = get_tokenizer(args.tokenizer)?;
-
     let config = clip::ClipConfig::vit_base_patch32();
-
     let device = candle_examples::device(args.cpu)?;
-
     let vec_imgs = match args.images {
         Some(imgs) => imgs,
         None => vec![
@@ -103,43 +87,29 @@ pub fn main() -> anyhow::Result<()> {
             "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
         ],
     };
-
-    // let image = load_image(args.image, config.image_size)?.to_device(&device)?;
     let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
-
     let vb =
         unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
-
     let model = clip::ClipModel::new(vb, &config)?;
-
     let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
-
     let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
-
     let softmax_image = softmax(&logits_per_image, 1)?;
-
     let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
-
-    info!("softmax_image_vec: {:?}", softmax_image_vec);
-
+    println!("softmax_image_vec: {:?}", softmax_image_vec);
     let probability_vec = softmax_image_vec
        .iter()
        .map(|v| v * 100.0)
        .collect::<Vec<f32>>();
-
     let probability_per_image = probability_vec.len() / vec_imgs.len();
-
     for (i, img) in vec_imgs.iter().enumerate() {
         let start = i * probability_per_image;
         let end = start + probability_per_image;
         let prob = &probability_vec[start..end];
-        info!("\n\nResults for image: {}\n", img);
-
+        println!("\n\nResults for image: {}\n", img);
         for (i, p) in prob.iter().enumerate() {
-            info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
+            println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
         }
     }
-
     Ok(())
 }
 
@@ -156,7 +126,6 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
         }
         Some(file) => file.into(),
     };
-
     Tokenizer::from_file(tokenizer).map_err(E::msg)
 }
 
@@ -169,7 +138,6 @@ pub fn tokenize_sequences(
        .get_vocab(true)
        .get("<|endoftext|>")
        .ok_or(E::msg("No pad token"))?;
-
     let vec_seq = match sequences {
         Some(seq) => seq,
         None => vec![
@@ -178,16 +146,12 @@ pub fn tokenize_sequences(
             "a robot holding a candle".to_string(),
         ],
     };
-
     let mut tokens = vec![];
-
     for seq in vec_seq.clone() {
         let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
         tokens.push(encoding.get_ids().to_vec());
     }
-
     let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
-
     // Pad the sequences to have the same length
     for token_vec in tokens.iter_mut() {
         let len_diff = max_len - token_vec.len();
@@ -195,8 +159,6 @@ pub fn tokenize_sequences(
             token_vec.extend(vec![pad_id; len_diff]);
         }
     }
-
     let input_ids = Tensor::new(tokens, device)?;
-
     Ok((input_ids, vec_seq))
 }
candle-examples/examples/siglip/README.md (new file, 24 lines)

@@ -0,0 +1,24 @@
## SigLIP

SigLIP is a multi-modal text-vision model that improves over CLIP by using a
sigmoid-based loss; see the model card on
[HuggingFace](https://huggingface.co/google/siglip-base-patch16-224).

### Running an example

```
$ cargo run --features cuda -r --example siglip
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]


Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg

Probability: 0.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 100.0000% Text: a robot holding a candle


Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg

Probability: 100.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 0.0000% Text: a robot holding a candle
```
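For context on the sigmoid loss mentioned above: instead of CLIP's batch-wise softmax over the logit matrix, SigLIP scores every image-text pair as an independent binary classification, which is also why the model below carries a learned `logit_bias` alongside `logit_scale`. A minimal candle-style sketch of that objective, not part of this commit (`siglip_loss`, its scalar `scale`/`bias` parameters, and the availability of `Tensor::eye` are all illustrative assumptions):

```rust
use candle::{Result, Tensor};

/// Sketch of SigLIP's pairwise sigmoid loss (not part of the commit).
/// `img` and `txt` are L2-normalized (n, d) feature matrices; `scale` and
/// `bias` stand in for `logit_scale.exp()` and `logit_bias`.
fn siglip_loss(img: &Tensor, txt: &Tensor, scale: f64, bias: f64) -> Result<Tensor> {
    let n = img.dim(0)?;
    // Every text paired with every image: an (n, n) logit matrix.
    let logits = (txt.matmul(&img.t()?)? * scale)?.affine(1.0, bias)?;
    // +1 on the diagonal (matching pairs), -1 off-diagonal; `Tensor::eye` is
    // assumed here for the diagonal.
    let labels = ((Tensor::eye(n, logits.dtype(), logits.device())? * 2.0)? - 1.0)?;
    // -log sigmoid(labels * logits) = log(1 + exp(-labels * logits)),
    // summed over all pairs and averaged over the batch.
    let z = (labels * logits)?;
    ((z.neg()?.exp()? + 1.0)?.log()?).sum_all()? / n as f64
}
```

A real training loop would stream chunks of the pair matrix rather than materialize it whole; this only shows the per-pair objective.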
candle-examples/examples/siglip/main.rs (new file, 153 lines)

@@ -0,0 +1,153 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::siglip;

use tokenizers::Tokenizer;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long, use_value_delimiter = true)]
    images: Option<Vec<String>>,

    #[arg(long)]
    cpu: bool,

    #[arg(long, use_value_delimiter = true)]
    sequences: Option<Vec<String>>,
}

fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
    let img = image::ImageReader::open(path)?.decode()?;
    let (height, width) = (image_size, image_size);
    let img = img.resize_to_fill(
        width as u32,
        height as u32,
        image::imageops::FilterType::Triangle,
    );
    let img = img.to_rgb8();
    let img = img.into_raw();
    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
        .permute((2, 0, 1))?
        .to_dtype(DType::F32)?
        .affine(2. / 255., -1.)?;
    Ok(img)
}

fn load_images<T: AsRef<std::path::Path>>(
    paths: &Vec<T>,
    image_size: usize,
) -> anyhow::Result<Tensor> {
    let mut images = vec![];
    for path in paths {
        let tensor = load_image(path, image_size)?;
        images.push(tensor);
    }
    let images = Tensor::stack(&images, 0)?;
    Ok(images)
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("google/siglip-base-patch16-224".to_string());
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };
    let tokenizer = get_tokenizer(args.tokenizer)?;
    let config = siglip::Config::base_patch16_224();
    let device = candle_examples::device(args.cpu)?;
    let vec_imgs = match args.images {
        Some(imgs) => imgs,
        None => vec![
            "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
            "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
        ],
    };
    let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
    let model = siglip::Model::new(&config, vb)?;
    let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
    let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
    let softmax_image = softmax(&logits_per_image, 1)?;
    let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
    println!("softmax_image_vec: {:?}", softmax_image_vec);
    let probability_vec = softmax_image_vec
        .iter()
        .map(|v| v * 100.0)
        .collect::<Vec<f32>>();
    let probability_per_image = probability_vec.len() / vec_imgs.len();
    for (i, img) in vec_imgs.iter().enumerate() {
        let start = i * probability_per_image;
        let end = start + probability_per_image;
        let prob = &probability_vec[start..end];
        println!("\n\nResults for image: {}\n", img);
        for (i, p) in prob.iter().enumerate() {
            println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
        }
    }
    Ok(())
}

pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
    let tokenizer = match tokenizer {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("google/siglip-base-patch16-224".to_string());
            api.get("tokenizer.json")?
        }
        Some(file) => file.into(),
    };

    Tokenizer::from_file(tokenizer).map_err(E::msg)
}

pub fn tokenize_sequences(
    config: &siglip::Config,
    sequences: Option<Vec<String>>,
    tokenizer: &Tokenizer,
    device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
    let pad_id = config.text_config.pad_token_id;
    let vec_seq = match sequences {
        Some(seq) => seq,
        None => vec![
            "a cycling race".to_string(),
            "a photo of two cats".to_string(),
            "a robot holding a candle".to_string(),
        ],
    };
    let mut tokens = vec![];
    for seq in vec_seq.clone() {
        let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
        tokens.push(encoding.get_ids().to_vec());
    }
    let max_len = config.text_config.max_position_embeddings;
    // Pad the sequences to have the same length
    for token_vec in tokens.iter_mut() {
        let len_diff = max_len - token_vec.len();
        if len_diff > 0 {
            token_vec.extend(vec![pad_id; len_diff]);
        }
    }
    let input_ids = Tensor::new(tokens, device)?;
    Ok((input_ids, vec_seq))
}
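One difference from the clip example above is worth calling out: `tokenize_sequences` here pads every prompt to the fixed `max_position_embeddings` (64) with the configured `pad_token_id`, rather than to the longest sequence in the batch, because the text tower pools the hidden state at the last position. Note also that the code above only pads; a prompt longer than 64 tokens would make `max_len - token_vec.len()` underflow. A hypothetical helper (not in the commit) capturing the same rule plus truncation:

```rust
// Hypothetical helper: pad (and truncate) token ids to a fixed length so the
// last position is well-defined for every batch element.
fn pad_to_fixed(mut ids: Vec<u32>, pad_id: u32, max_len: usize) -> Vec<u32> {
    ids.truncate(max_len);
    ids.resize(max_len, pad_id);
    ids
}

fn main() {
    assert_eq!(pad_to_fixed(vec![7, 8], 1, 4), vec![7, 8, 1, 1]);
    println!("padding rule holds");
}
```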
candle-transformers/src/models/clip/mod.rs

@@ -92,28 +92,23 @@ impl ClipConfig {
 impl ClipModel {
     pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
         let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
-
         let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
-
         let visual_projection = candle_nn::linear_no_bias(
             c.vision_config.embed_dim,
             c.vision_config.projection_dim,
             vs.pp("visual_projection"),
         )?;
-
         let text_projection = candle_nn::linear_no_bias(
             c.text_config.embed_dim,
             c.text_config.projection_dim,
             vs.pp("text_projection"),
         )?;
-
         // originally nn.Parameter
         let logit_scale = if vs.contains_tensor("logit_scale") {
             vs.get(&[], "logit_scale")?
         } else {
             Tensor::new(&[c.logit_scale_init_value], vs.device())?
         };
-
         Ok(Self {
             text_model,
             vision_model,
candle-transformers/src/models/clip/text_model.rs

@@ -77,7 +77,7 @@ impl ClipTextEmbeddings {
         )?;
         let position_ids =
             Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
-        Ok(ClipTextEmbeddings {
+        Ok(Self {
             token_embedding,
             position_embedding,
             position_ids,
@@ -298,7 +298,7 @@ impl ClipTextTransformer {
         })
     }
 
-    // TODO: rewrrite to newer version
+    // TODO: rewrite to newer version
     fn build_causal_attention_mask(
         bsz: usize,
         seq_len: usize,
candle-transformers/src/models/fastvit.rs

@@ -11,13 +11,13 @@ use candle_nn::{
     BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
 };
 
-#[derive(Clone, Debug)]
+#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
 pub struct Config {
-    exp_ratio: usize,
-    in_channels: usize,
-    blocks: [usize; 4],
-    attn: bool,
-    lkc_use_act: bool,
+    pub exp_ratio: usize,
+    pub in_channels: usize,
+    pub blocks: [usize; 4],
+    pub attn: bool,
+    pub lkc_use_act: bool,
 }
 
 impl Config {
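The fastvit change is orthogonal to SigLIP itself: deriving `serde::Serialize`/`serde::Deserialize` and making the fields public lets a `Config` be built, inspected, or round-tripped through JSON from outside the crate. A sketch of what that enables (the field values are made up for illustration, and a `serde_json` dependency is assumed):

```rust
use candle_transformers::models::fastvit;

fn demo() -> anyhow::Result<()> {
    // Illustrative values only, not a real fastvit variant.
    let cfg = fastvit::Config {
        exp_ratio: 4,
        in_channels: 76,
        blocks: [2, 2, 4, 2],
        attn: false,
        lkc_use_act: false,
    };
    // Round-trip through JSON, which the new derives make possible.
    let json = serde_json::to_string(&cfg)?;
    let cfg2: fastvit::Config = serde_json::from_str(&json)?;
    println!("{cfg2:?}");
    Ok(())
}
```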
candle-transformers/src/models/mod.rs

@@ -76,6 +76,7 @@ pub mod rwkv_v5;
 pub mod rwkv_v6;
 pub mod segformer;
 pub mod segment_anything;
+pub mod siglip;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod starcoder2;
candle-transformers/src/models/siglip.rs (new file, 608 lines)

@@ -0,0 +1,608 @@
use crate::models::clip::div_l2_norm;
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
#[derive(serde::Deserialize, Clone, Debug)]
pub struct TextConfig {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub max_position_embeddings: usize,
    pub hidden_act: candle_nn::Activation,
    pub layer_norm_eps: f64,
    pub pad_token_id: u32,
    pub bos_token_id: u32,
    pub eos_token_id: u32,
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
#[derive(serde::Deserialize, Clone, Debug)]
pub struct VisionConfig {
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_channels: usize,
    pub image_size: usize,
    pub patch_size: usize,
    pub hidden_act: candle_nn::Activation,
    pub layer_norm_eps: f64,
}

trait TransformerConfig {
    fn hidden_size(&self) -> usize;
    fn intermediate_size(&self) -> usize;
    fn num_attention_heads(&self) -> usize;
    fn num_hidden_layers(&self) -> usize;
    fn layer_norm_eps(&self) -> f64;
    fn hidden_act(&self) -> candle_nn::Activation;
}

impl TransformerConfig for TextConfig {
    fn hidden_size(&self) -> usize {
        self.hidden_size
    }
    fn intermediate_size(&self) -> usize {
        self.intermediate_size
    }
    fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }
    fn num_hidden_layers(&self) -> usize {
        self.num_hidden_layers
    }
    fn layer_norm_eps(&self) -> f64 {
        self.layer_norm_eps
    }
    fn hidden_act(&self) -> candle_nn::Activation {
        self.hidden_act
    }
}

impl TransformerConfig for VisionConfig {
    fn hidden_size(&self) -> usize {
        self.hidden_size
    }
    fn intermediate_size(&self) -> usize {
        self.intermediate_size
    }
    fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }
    fn num_hidden_layers(&self) -> usize {
        self.num_hidden_layers
    }
    fn layer_norm_eps(&self) -> f64 {
        self.layer_norm_eps
    }
    fn hidden_act(&self) -> candle_nn::Activation {
        self.hidden_act
    }
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L228
#[derive(serde::Deserialize, Clone, Debug)]
pub struct Config {
    pub text_config: TextConfig,
    pub vision_config: VisionConfig,
}

impl Config {
    pub fn base_patch16_224() -> Self {
        let text_config = TextConfig {
            // https://huggingface.co/google/siglip-base-patch16-224/blob/main/config.json
            hidden_size: 768,
            intermediate_size: 3072,
            num_attention_heads: 12,
            vocab_size: 32000,
            // Default values.
            pad_token_id: 1,
            bos_token_id: 49406,
            eos_token_id: 49407,
            layer_norm_eps: 1e-6,
            hidden_act: candle_nn::Activation::GeluPytorchTanh,
            max_position_embeddings: 64,
            num_hidden_layers: 12,
        };
        let vision_config = VisionConfig {
            patch_size: 16,
            // Default values.
            hidden_size: 768,
            intermediate_size: 3072,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            num_channels: 3,
            image_size: 224,
            hidden_act: candle_nn::Activation::GeluPytorchTanh,
            layer_norm_eps: 1e-6,
        };
        Self {
            text_config,
            vision_config,
        }
    }
}

#[derive(Clone, Debug)]
struct MultiheadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
}

impl MultiheadAttention {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let h = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let w_in_proj = vb.get((3 * h, h), "in_proj_weight")?.chunk(3, 0)?;
        let b_in_proj = vb.get(3 * h, "in_proj_bias")?.chunk(3, 0)?;
        let q_proj = Linear::new(w_in_proj[0].clone(), Some(b_in_proj[0].clone()));
        let k_proj = Linear::new(w_in_proj[1].clone(), Some(b_in_proj[1].clone()));
        let v_proj = Linear::new(w_in_proj[2].clone(), Some(b_in_proj[2].clone()));
        let out_proj = linear(h, h, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
        })
    }

    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n, c) = x.dims3()?;
        x.reshape((b, n, self.num_heads, c / self.num_heads))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
        x.transpose(1, 2)?
            .reshape((b, n_tokens, n_heads * c_per_head))
    }

    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let q = self.q_proj.forward(&q.contiguous()?)?;
        let k = self.k_proj.forward(&k.contiguous()?)?;
        let v = self.v_proj.forward(&v.contiguous()?)?;

        let q = self.separate_heads(&q)?;
        let k = self.separate_heads(&k)?;
        let v = self.separate_heads(&v)?;

        let (_, _, _, c_per_head) = q.dims4()?;
        let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;

        let out = attn.matmul(&v)?;
        self.recombine_heads(&out)?.apply(&self.out_proj)
    }
}

#[derive(Debug, Clone)]
struct MultiheadAttentionPoolingHead {
    probe: Tensor,
    attention: MultiheadAttention,
    layernorm: LayerNorm,
    mlp: Mlp,
}

impl MultiheadAttentionPoolingHead {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
        let probe = vb.get((1, 1, cfg.hidden_size), "probe")?;
        let attention = MultiheadAttention::new(cfg, vb.pp("attention"))?;
        Ok(Self {
            probe,
            attention,
            layernorm,
            mlp,
        })
    }
}

impl Module for MultiheadAttentionPoolingHead {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let batch_size = xs.dim(0)?;
        let probe = self.probe.repeat((batch_size, 1, 1))?;
        let xs = self.attention.forward(&probe, xs, xs)?;
        let residual = &xs;
        let xs = xs.apply(&self.layernorm)?.apply(&self.mlp)?;
        (xs + residual)?.i((.., 0))
    }
}

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
    head_dim: usize,
    scale: f64,
}

impl Attention {
    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
        let embed_dim = cfg.hidden_size();
        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
        let num_heads = cfg.num_attention_heads();
        let head_dim = embed_dim / num_heads;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
            head_dim,
            scale: (head_dim as f64).powf(-0.5),
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (batch_size, q_len, _) = xs.dims3()?;
        let query_states = xs.apply(&self.q_proj)?;
        let key_states = xs.apply(&self.k_proj)?;
        let value_states = xs.apply(&self.v_proj)?;

        let shape = (batch_size, q_len, self.num_heads, self.head_dim);
        let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
        let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
        let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;

        let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
        let attn_weights = match attention_mask {
            None => attn_weights,
            Some(mask) => attn_weights.broadcast_add(mask)?,
        };
        // The original implementation upcasts to f32 but candle_nn::ops::softmax should handle this properly.
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_outputs = attn_weights
            .matmul(&value_states)?
            .transpose(1, 2)?
            .reshape((batch_size, q_len, ()))?
            .apply(&self.out_proj)?;
        Ok(attn_outputs)
    }
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L599
#[derive(Debug, Clone)]
struct Mlp {
    fc1: Linear,
    fc2: Linear,
    activation_fn: candle_nn::Activation,
}

impl Mlp {
    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size();
        let intermediate_size = cfg.intermediate_size();
        let fc1 = candle_nn::linear(hidden_size, intermediate_size, vb.pp("fc1"))?;
        let fc2 = candle_nn::linear(intermediate_size, hidden_size, vb.pp("fc2"))?;
        Ok(Self {
            fc1,
            fc2,
            activation_fn: cfg.hidden_act(),
        })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &candle::Tensor) -> Result<candle::Tensor> {
        xs.apply(&self.fc1)?
            .apply(&self.activation_fn)?
            .apply(&self.fc2)
    }
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L614
#[derive(Debug, Clone)]
struct EncoderLayer {
    self_attn: Attention,
    layer_norm1: LayerNorm,
    mlp: Mlp,
    layer_norm2: LayerNorm,
}

impl EncoderLayer {
    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size();
        let layer_norm_eps = cfg.layer_norm_eps();
        let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
        let layer_norm1 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm1"))?;
        let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
        let layer_norm2 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm2"))?;
        Ok(Self {
            self_attn,
            layer_norm1,
            mlp,
            layer_norm2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.layer_norm1)?;
        let xs = self.self_attn.forward(&xs, attention_mask)?;
        let xs = (residual + xs)?;
        let residual = &xs;
        let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
        let xs = (xs + residual)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
        let mut layers = vec![];
        let vb = vb.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers() {
            let layer = EncoderLayer::new(cfg, vb.pp(layer_idx))?;
            layers.push(layer)
        }
        Ok(Self { layers })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, attention_mask)?
        }
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct VisionEmbeddings {
    patch_embedding: candle_nn::Conv2d,
    position_embedding: candle_nn::Embedding,
    position_ids: Tensor,
}

impl VisionEmbeddings {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let conv2d_cfg = candle_nn::Conv2dConfig {
            stride: cfg.patch_size,
            ..Default::default()
        };
        let patch_embedding = candle_nn::conv2d(
            cfg.num_channels,
            cfg.hidden_size,
            cfg.patch_size,
            conv2d_cfg,
            vb.pp("patch_embedding"),
        )?;
        let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
        let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?;
        let position_embedding =
            candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?;
        Ok(Self {
            patch_embedding,
            position_embedding,
            position_ids,
        })
    }
}

impl Module for VisionEmbeddings {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_batch, _channels, _height, _width) = xs.dims4()?;
        let embeddings = xs.apply(&self.patch_embedding)?;
        let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
        embeddings.broadcast_add(&position_embedding)
    }
}

#[derive(Debug, Clone)]
struct VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    head: Option<MultiheadAttentionPoolingHead>,
}

impl VisionTransformer {
    fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let post_layernorm =
            layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
        let head = if use_head {
            Some(MultiheadAttentionPoolingHead::new(cfg, vb.pp("head"))?)
        } else {
            None
        };
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            head,
        })
    }
}

impl Module for VisionTransformer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.embeddings)?;
        let xs = self.encoder.forward(&xs, None)?;
        let xs = xs.apply(&self.post_layernorm)?;
        match self.head.as_ref() {
            None => Ok(xs),
            Some(h) => xs.apply(h),
        }
    }
}

#[derive(Debug, Clone)]
pub struct VisionModel {
    vision_model: VisionTransformer,
}

impl VisionModel {
    pub fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
        let vision_model = VisionTransformer::new(cfg, use_head, vb)?;
        Ok(Self { vision_model })
    }
}

impl Module for VisionModel {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.vision_model)
    }
}

#[derive(Debug, Clone)]
struct TextEmbeddings {
    token_embedding: candle_nn::Embedding,
    position_embedding: candle_nn::Embedding,
    position_ids: Tensor,
}

impl TextEmbeddings {
    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
        let token_embedding =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
        let position_embedding = candle_nn::embedding(
            cfg.max_position_embeddings,
            cfg.hidden_size,
            vb.pp("position_embedding"),
        )?;
        let position_ids =
            Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
        Ok(Self {
            token_embedding,
            position_embedding,
            position_ids,
        })
    }
}

impl Module for TextEmbeddings {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let seq_length = input_ids.dim(D::Minus1)?;
        let inputs_embeds = self.token_embedding.forward(input_ids)?;
        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
        let position_embedding = self.position_embedding.forward(&position_ids)?;
        inputs_embeds.broadcast_add(&position_embedding)
    }
}

#[derive(Debug, Clone)]
pub struct TextTransformer {
    embeddings: TextEmbeddings,
    encoder: Encoder,
    final_layer_norm: LayerNorm,
    pub head: Linear,
}

impl TextTransformer {
    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let final_layer_norm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            vb.pp("final_layer_norm"),
        )?;
        let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            encoder,
            final_layer_norm,
            head,
        })
    }
}
impl Module for TextTransformer {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let (_bsz, seq_len) = input_ids.dims2()?;
        let input_ids = self.embeddings.forward(input_ids)?;
        let input_ids = self.encoder.forward(&input_ids, None)?;
        let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
        last_hidden_state
            .i((.., seq_len - 1, ..))?
            .contiguous()?
            .apply(&self.head)
    }
}

#[derive(Debug, Clone)]
pub struct TextModel {
    pub text_model: TextTransformer,
}

impl TextModel {
    pub fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
        let text_model = TextTransformer::new(cfg, vb)?;
        Ok(Self { text_model })
    }
}

impl Module for TextModel {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.text_model)
    }
}

#[derive(Clone, Debug)]
pub struct Model {
    text_model: TextModel,
    vision_model: VisionModel,
    logit_bias: Tensor,
    logit_scale: Tensor,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let text_model = TextModel::new(&cfg.text_config, vb.pp("text_model"))?;
        let vision_model = VisionModel::new(&cfg.vision_config, true, vb.pp("vision_model"))?;
        let logit_scale = vb.get(&[1], "logit_scale")?;
        let logit_bias = vb.get(&[1], "logit_bias")?;
        Ok(Self {
            text_model,
            vision_model,
            logit_bias,
            logit_scale,
        })
    }

    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
        input_ids.apply(&self.text_model)
    }

    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
        pixel_values.apply(&self.vision_model)
    }

    pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
        let image_features = self.get_image_features(pixel_values)?;
        let text_features = self.get_text_features(input_ids)?;
        let image_features_normalized = div_l2_norm(&image_features)?;
        let text_features_normalized = div_l2_norm(&text_features)?;
        let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
        let logit_scale = self.logit_scale.exp()?;
        let logits_per_text = logits_per_text
            .broadcast_mul(&logit_scale)?
            .broadcast_add(&self.logit_bias)?;
        let logits_per_image = logits_per_text.t()?;
        Ok((logits_per_text, logits_per_image))
    }
}
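The example only exercises `Model::forward`, but the public `get_image_features`/`get_text_features` above also allow embedding one modality at a time, e.g. for retrieval. A sketch under the same assumptions as the example (`base_patch16_224` shapes: 224x224 pixel values in [-1, 1], token ids padded to 64; the `embed` helper and its dummy inputs are illustrative):

```rust
use candle::{DType, Device, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::siglip;

fn embed(model_file: &str) -> anyhow::Result<()> {
    let device = Device::Cpu;
    let config = siglip::Config::base_patch16_224();
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = siglip::Model::new(&config, vb)?;
    // Dummy inputs with the expected shapes: (batch, 3, 224, 224) pixels and
    // (batch, 64) token ids.
    let pixels = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    let ids = Tensor::zeros((1, 64), DType::U32, &device)?;
    // Each tower returns (batch, hidden_size); L2-normalize before comparing.
    let img = model.get_image_features(&pixels)?;
    let img = img.broadcast_div(&img.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)?;
    let txt = model.get_text_features(&ids)?;
    let txt = txt.broadcast_div(&txt.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)?;
    println!("cosine similarity: {}", txt.matmul(&img.t()?)?);
    Ok(())
}
```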