Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 11:56:45 +00:00)
Add some vision transformers models (#1132)
* Start adding vision-transformers.
* Add self-attn.
* More vision transformers.
* vit-vit.
* Add the actual vit model.
* Add the example code for the vision transformers.
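
Assuming the standard cargo example setup used throughout this repository, the new example can be run along these lines:

cargo run --example vit --release -- --image <path-to-an-image>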
candle-examples/examples/vit/main.rs (new file, 59 lines)
@@ -0,0 +1,59 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, D};
use candle_nn::VarBuilder;
use candle_transformers::models::vit;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("google/vit-base-patch16-224".into());
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = vit::Model::new(&vit::Config::vit_base_patch16_224(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
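
The top-5 post-processing at the end of main can be exercised on its own. Below is a minimal sketch with made-up three-class logits, using the same candle / candle-nn APIs as the example above:

use candle::{Device, Tensor, D};

fn main() -> anyhow::Result<()> {
    // Made-up logits for a toy three-class problem.
    let logits = Tensor::new(&[1.0f32, 3.0, 2.0], &Device::Cpu)?;
    // Softmax over the last dimension turns logits into probabilities.
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?.to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    // Sort descending by probability, exactly as in the example above.
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(idx, pr) in prs.iter().take(2) {
        println!("class {idx}: {:.2}%", 100. * pr);
    }
    Ok(())
}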
candle-transformers/src/models/mod.rs
@@ -19,6 +19,7 @@ pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod t5;
+pub mod vit;
 pub mod whisper;
 pub mod with_tracing;
 pub mod wuerstchen;
candle-transformers/src/models/vit.rs (new file, 382 lines)
@@ -0,0 +1,382 @@
#![allow(unused)]
use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear};
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};

// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
#[derive(Debug, Clone)]
pub struct Config {
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    hidden_act: candle_nn::Activation,
    layer_norm_eps: f64,
    image_size: usize,
    patch_size: usize,
    num_channels: usize,
    qkv_bias: bool,
}

impl Config {
    // https://huggingface.co/google/vit-base-patch16-224/blob/main/config.json
    pub fn vit_base_patch16_224() -> Self {
        Self {
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: candle_nn::Activation::Gelu,
            layer_norm_eps: 1e-12,
            image_size: 224,
            patch_size: 16,
            num_channels: 3,
            qkv_bias: true,
        }
    }
}
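
// With the vit_base_patch16_224 settings above, each image yields
// (224 / 16)^2 = 196 patches, i.e. 197 tokens once the CLS token is
// prepended, and the per-head dimension is 768 / 12 = 64.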
#[derive(Debug, Clone)]
struct PatchEmbeddings {
    num_patches: usize,
    projection: Conv2d,
}

impl PatchEmbeddings {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let image_size = cfg.image_size;
        let patch_size = cfg.patch_size;
        let num_patches = (image_size / patch_size) * (image_size / patch_size);
        let conv_cfg = candle_nn::Conv2dConfig {
            stride: patch_size,
            ..Default::default()
        };
        let projection = conv2d(
            cfg.num_channels,
            cfg.hidden_size,
            patch_size,
            conv_cfg,
            vb.pp("projection"),
        )?;
        Ok(Self {
            num_patches,
            projection,
        })
    }
}

impl Module for PatchEmbeddings {
    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
        self.projection
            .forward(pixel_values)?
            .flatten_from(2)?
            .transpose(1, 2)
    }
}
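
// For a 224x224 input, the projection conv above produces (b, hidden_size, 14, 14);
// flatten_from(2) followed by transpose(1, 2) turns this into (b, 196, hidden_size).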
#[derive(Debug, Clone)]
struct Embeddings {
    cls_token: Tensor,
    mask_token: Option<Tensor>,
    patch_embeddings: PatchEmbeddings,
    position_embeddings: Tensor,
    hidden_size: usize,
}

impl Embeddings {
    fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
        let mask_token = if use_mask_token {
            Some(vb.get((1, 1, hidden_size), "mask_token")?)
        } else {
            None
        };
        let patch_embeddings = PatchEmbeddings::new(cfg, vb.pp("patch_embeddings"))?;
        let num_patches = patch_embeddings.num_patches;
        let position_embeddings =
            vb.get((1, num_patches + 1, hidden_size), "position_embeddings")?;
        Ok(Self {
            cls_token,
            mask_token,
            patch_embeddings,
            position_embeddings,
            hidden_size,
        })
    }
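
    // Position-embedding interpolation for inputs that are not 224x224 is not
    // implemented yet: calling forward with interpolate_pos_encoding = true
    // reaches the todo!() below and panics.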
    fn interpolate_pos_encoding(
        &self,
        embeddings: &Tensor,
        height: usize,
        width: usize,
    ) -> Result<Tensor> {
        todo!()
    }

    fn forward(
        &self,
        pixel_values: &Tensor,
        bool_masked_pos: Option<&Tensor>,
        interpolate_pos_encoding: bool,
    ) -> Result<Tensor> {
        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
        let embeddings = self.patch_embeddings.forward(pixel_values)?;
        let embeddings = match (bool_masked_pos, &self.mask_token) {
            (None, _) => embeddings,
            (Some(_), None) => candle::bail!("bool_masked_pos set without mask_token"),
            (Some(bool_masked_pos), Some(mask_tokens)) => {
                let seq_len = embeddings.dim(1)?;
                let mask_tokens = mask_tokens.broadcast_as((b_size, seq_len, self.hidden_size))?;
                let mask = bool_masked_pos
                    .unsqueeze(D::Minus1)?
                    .to_dtype(mask_tokens.dtype())?;
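                // mask is 1.0 at masked positions and 0.0 elsewhere, so this
                // picks mask_tokens where masked and keeps the patch embeddings
                // everywhere else: mask_tokens * mask - embeddings * (mask - 1).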
                ((mask_tokens * &mask)? - (embeddings * (mask - 1.)?)?)?
            }
        };
        let cls_tokens = self.cls_token.broadcast_as((b_size, 1, self.hidden_size))?;
        let embeddings = Tensor::cat(&[&cls_tokens, &embeddings], 1)?;
        if interpolate_pos_encoding {
            let pos = self.interpolate_pos_encoding(&embeddings, height, width)?;
            embeddings.broadcast_add(&pos)
        } else {
            embeddings.broadcast_add(&self.position_embeddings)
        }
    }
}

#[derive(Debug, Clone)]
struct SelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    num_attention_heads: usize,
    attention_head_size: usize,
}

impl SelfAttention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
        let num_attention_heads = cfg.num_attention_heads;
        let all_head_size = num_attention_heads * attention_head_size;
        let linear = |name| {
            if cfg.qkv_bias {
                linear(cfg.hidden_size, all_head_size, vb.pp(name))
            } else {
                linear_no_bias(cfg.hidden_size, all_head_size, vb.pp(name))
            }
        };
        let query = linear("query")?;
        let key = linear("key")?;
        let value = linear("value")?;
        Ok(Self {
            query,
            key,
            value,
            num_attention_heads,
            attention_head_size,
        })
    }

    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, _) = xs.dims3()?;
        xs.reshape((
            b_size,
            seq_len,
            self.num_attention_heads,
            self.attention_head_size,
        ))?
        .permute((0, 2, 1, 3))
    }
}
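
// transpose_for_scores reshapes (b, seq_len, hidden_size) into
// (b, num_attention_heads, seq_len, attention_head_size) so that the scaled
// dot-product attention below runs as a batched matmul per head.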
impl Module for SelfAttention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let query = self.query.forward(xs)?;
        let key = self.key.forward(xs)?;
        let value = self.value.forward(xs)?;

        let query = self.transpose_for_scores(&query)?.contiguous()?;
        let key = self.transpose_for_scores(&key)?.contiguous()?;
        let value = self.transpose_for_scores(&value)?.contiguous()?;

        let attention_scores =
            (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
        attention_probs
            .matmul(&value)?
            .permute((0, 2, 1, 3))?
            .contiguous()?
            .flatten_from(D::Minus2)
    }
}

#[derive(Debug, Clone)]
struct SelfOutput {
    dense: Linear,
}

impl SelfOutput {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        Ok(Self { dense })
    }
}

impl Module for SelfOutput {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.dense)
    }
}

#[derive(Debug, Clone)]
struct Attention {
    attention: SelfAttention,
    output: SelfOutput,
}

impl Attention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention = SelfAttention::new(cfg, vb.pp("attention"))?;
        let output = SelfOutput::new(cfg, vb.pp("output"))?;
        Ok(Self { attention, output })
    }
}

impl Module for Attention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.attention)?.apply(&self.output)
    }
}

#[derive(Debug, Clone)]
struct Intermediate {
    dense: Linear,
    intermediate_act_fn: candle_nn::Activation,
}

impl Intermediate {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
        Ok(Self {
            dense,
            intermediate_act_fn: cfg.hidden_act,
        })
    }
}

impl Module for Intermediate {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
    }
}

#[derive(Debug, Clone)]
struct Output {
    dense: Linear,
}

impl Output {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
        Ok(Self { dense })
    }

    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        xs.apply(&self.dense)? + input_tensor
    }
}

#[derive(Debug, Clone)]
struct Layer {
    attention: Attention,
    intermediate: Intermediate,
    output: Output,
    layernorm_before: LayerNorm,
    layernorm_after: LayerNorm,
}

impl Layer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention = Attention::new(cfg, vb.pp("attention"))?;
        let intermediate = Intermediate::new(cfg, vb.pp("intermediate"))?;
        let output = Output::new(cfg, vb.pp("output"))?;
        let h_sz = cfg.hidden_size;
        let layernorm_before = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_before"))?;
        let layernorm_after = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_after"))?;
        Ok(Self {
            attention,
            intermediate,
            output,
            layernorm_after,
            layernorm_before,
        })
    }
}

impl Module for Layer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = (xs.apply(&self.layernorm_before)?.apply(&self.attention)? + xs)?;
        let ys = xs.apply(&self.layernorm_after)?.apply(&self.intermediate)?;
        self.output.forward(&ys, &xs)
    }
}
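
// Layer is a standard pre-LayerNorm transformer block:
//   x = x + Attention(LayerNorm(x)); x = x + Mlp(LayerNorm(x)),
// with the second residual added inside Output::forward.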
#[derive(Debug, Clone)]
struct Encoder {
    layers: Vec<Layer>,
}

impl Encoder {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("layer");
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
            let layer = Layer::new(cfg, vb.pp(i))?;
            layers.push(layer)
        }
        Ok(Self { layers })
    }
}

impl Module for Encoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = xs.apply(layer)?
        }
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embeddings: Embeddings,
    encoder: Encoder,
    layernorm: LayerNorm,
    // no need for pooling layer for image classification
    classifier: Linear,
}

impl Model {
    pub fn new(cfg: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
        let vb_v = vb.pp("vit");
        let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
        let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
        let classifier = linear(cfg.hidden_size, num_labels, vb.pp("classifier"))?;
        Ok(Self {
            embeddings,
            encoder,
            layernorm,
            classifier,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embedding_output = self.embeddings.forward(xs, None, false)?;
        let encoder_outputs = self.encoder.forward(&embedding_output)?;
        // Take the CLS token, apply the final layer norm, then the
        // classification head.
        encoder_outputs
            .i((.., 0, ..))?
            .apply(&self.layernorm)?
            .apply(&self.classifier)
    }
}
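
// Note on the parameter layout: it mirrors the Hugging Face
// ViTForImageClassification checkpoints, where the backbone weights live under
// the "vit" prefix while the "classifier" head sits at the root; hence the
// separate vb_v prefix above for everything except the classifier.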