diff --git a/candle-examples/examples/dinov2reg4/README.md b/candle-examples/examples/dinov2reg4/README.md new file mode 100644 index 00000000..ac86ca69 --- /dev/null +++ b/candle-examples/examples/dinov2reg4/README.md @@ -0,0 +1,25 @@ +# candle-dinov2-reg4 + +[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the lastest version of DINOv2 with registers. +In this example, it is used as an plant species classifier: the model returns the +probability for the image to belong to each of the 7806 PlantCLEF2024 categories. + +## Running some example + +```bash +# Download classes names and a plant picture to identify +curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt +curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg + +# Perform inference +cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg + +> Orchis simia Lam. : 45.55% +> Orchis × bergonii Nanteuil: 9.80% +> Orchis italica Poir. : 9.66% +> Orchis × angusticruris Franch.: 2.76% +> Orchis × bivonae Tod. : 2.54% + +``` + +![Orchis Simia](https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c) diff --git a/candle-examples/examples/dinov2reg4/main.rs b/candle-examples/examples/dinov2reg4/main.rs new file mode 100644 index 00000000..15270517 --- /dev/null +++ b/candle-examples/examples/dinov2reg4/main.rs @@ -0,0 +1,70 @@ +//! DINOv2 reg4 finetuned on PlantCLEF 2024 +//! https://arxiv.org/abs/2309.16588 +//! https://huggingface.co/spaces/BVRA/PlantCLEF2024 +//! https://zenodo.org/records/10848263 + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::Parser; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::dinov2reg4; + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?; + println!("loaded image {image:?}"); + + let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt"; + let classes: Vec = std::fs::read_to_string(f_species_id_mapping) + .expect("missing classes file") + .split('\n') + .map(|s| s.to_string()) + .collect(); + + let model_file = match args.model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = + api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into()); + api.get( + "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors", + )? + } + Some(model) => model.into(), + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = dinov2reg4::vit_base(vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::()?; + let mut prs = prs.iter().enumerate().collect::>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!("{:24}: {:.2}%", classes[category_idx], 100. * pr); + } + Ok(()) +} diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs index cefbd71b..781dcd4f 100644 --- a/candle-examples/src/imagenet.rs +++ b/candle-examples/src/imagenet.rs @@ -17,6 +17,24 @@ pub fn load_image224>(p: P) -> Result { .broadcast_div(&std) } +/// Loads an image from disk using the image crate, this returns a tensor with shape +/// (3, 518, 518). imagenet normalization is applied. +/// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens). +pub fn load_image518>(p: P) -> Result { + let img = image::io::Reader::open(p)? + .decode() + .map_err(candle::Error::wrap)? + .resize_to_fill(518, 518, image::imageops::FilterType::Triangle); + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (518, 518, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; + (data.to_dtype(candle::DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std) +} + pub const CLASS_COUNT: i64 = 1000; pub const CLASSES: [&str; 1000] = [ diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs new file mode 100644 index 00000000..6bbe2e24 --- /dev/null +++ b/candle-transformers/src/models/dinov2reg4.rs @@ -0,0 +1,281 @@ +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +const IMG_SIZE: usize = 518; +const PATCH_SIZE: usize = 14; +const NUM_CLASSES: usize = 7806; // PlantCLEF2024 DINOv2 (https://zenodo.org/records/10848263) + +fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result { + if bias { + candle_nn::linear(in_dim, out_dim, vb) + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb) + } +} + +#[derive(Debug)] +struct Attention { + qkv: Linear, + proj: Linear, + num_heads: usize, + scale: f64, +} + +impl Attention { + fn new( + vb: VarBuilder, + dim: usize, + num_heads: usize, + qkv_bias: bool, + proj_bias: bool, + ) -> Result { + let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; + let scale = 1. / ((dim / num_heads) as f64).sqrt(); + Ok(Self { + qkv, + proj, + num_heads, + scale, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result { + let (b, n, c) = xs.dims3()?; + let qkv = self + .qkv + .forward(xs)? + .reshape((b, n, 3, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? // 02134 + .transpose(0, 1)? // 20134 + .transpose(2, 3)?; // 20314 + let q = (qkv.i(0)? * self.scale)?; + let k = qkv.i(1)?.contiguous()?; + let v = qkv.i(2)?.contiguous()?; + let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; + let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; + self.proj.forward(&attn) + } +} + +#[derive(Debug)] +struct LayerScale { + gamma: Tensor, +} + +impl LayerScale { + fn new(vb: VarBuilder, dim: usize) -> Result { + let gamma = vb.get(dim, "gamma")?; + Ok(Self { gamma }) + } +} + +impl Module for LayerScale { + fn forward(&self, xs: &Tensor) -> Result { + xs.broadcast_mul(&self.gamma) + } +} + +#[derive(Debug)] +struct Mlp { + fc1: Linear, + fc2: Linear, +} + +impl Mlp { + fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result { + let out_features = in_features; + let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; + let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.fc1.forward(xs)?.gelu()?; + self.fc2.forward(&xs) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + ls1: LayerScale, + norm2: LayerNorm, + mlp: Mlp, + ls2: LayerScale, +} + +impl Block { + fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result { + let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?; + let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; + let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; + let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?; + let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; + let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; + Ok(Self { + norm1, + attn, + ls1, + norm2, + mlp, + ls2, + }) + } +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> Result { + let residual = xs; + let xs = self + .ls1 + .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = self + .ls2 + .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; + xs + residual + } +} + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + patch_size: (usize, usize), + num_patches: usize, +} + +impl PatchEmbed { + fn new( + vb: VarBuilder, + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + ) -> Result { + let config = candle_nn::Conv2dConfig { + stride: patch_size, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; + let num_patches = (img_size / patch_size) * (img_size / patch_size); + Ok(Self { + proj, + patch_size: (patch_size, patch_size), + num_patches, + }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result { + let (_b, _c, h, w) = xs.dims4()?; + let (patch_h, patch_w) = self.patch_size; + if (h % patch_h) != 0 { + candle::bail!("image height {h} is not a multiple of patch height {patch_h}") + } + if (w % patch_w) != 0 { + candle::bail!("image width {w} is not a multiple of patch width {patch_w}") + } + let xs = self.proj.forward(xs)?; + let (b, c, h, w) = xs.dims4()?; + // flatten embeddings. + xs.reshape((b, c, h * w))?.transpose(1, 2) + } +} + +#[derive(Debug)] +pub struct DinoVisionTransformer { + patch_embed: PatchEmbed, + cls_token: Tensor, + reg_token: Tensor, + pos_embed: Tensor, + blocks: Vec, + norm: LayerNorm, + head: Linear, +} + +impl DinoVisionTransformer { + pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result { + let patch_embed = + PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; + let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; + let reg_token = vb.get((1, 4, embed_dim), "reg_token")?; + let pos_embed = vb.get((1, patch_embed.num_patches, embed_dim), "pos_embed")?; + let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?; + let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?; + let vb_b = vb.pp("blocks"); + let blocks = (0..depth) + .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .collect::>>()?; + Ok(Self { + patch_embed, + cls_token, + reg_token, + pos_embed, + blocks, + norm, + head, + }) + } + + fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result { + let npatch = xs.dim(1)? - 1; + let n = self.pos_embed.dim(1)? - 1; + let sqrt_n = (n as f64).sqrt(); + if npatch == n && w == h { + return Ok(self.pos_embed.clone()); + } + let patch_pos_embed = &self.pos_embed; + let dim = xs.dim(D::Minus1)?; + let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); + let patch_pos_embed = patch_pos_embed + .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? + .transpose(2, 3)? + .transpose(1, 2)?; + // This uses bicubic interpolation in the original implementation. + let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; + let el_count = patch_pos_embed.shape().elem_count(); + patch_pos_embed + .transpose(1, 2)? + .transpose(2, 3)? + .reshape((1, el_count / dim, dim)) + } + + fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result { + let (_b, _nc, w, h) = xs.dims4()?; + if (w != IMG_SIZE) || (h != IMG_SIZE) { + panic!("Error: The input tensor should have the shape: Bx3x518x518."); + } + let xs = self.patch_embed.forward(xs)?; + let xs = (&xs + &self.interpolate_pos_encoding(&xs, w, h)?)?; + let xs = Tensor::cat(&[&self.cls_token, &self.reg_token, &xs], 1)?; + Ok(xs) + } +} + +impl Module for DinoVisionTransformer { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = self.prepare_tokens_with_mask(xs)?; + for blk in self.blocks.iter() { + xs = blk.forward(&xs)? + } + let xs = self.norm.forward(&xs)?; + let xs_norm_clstoken = xs.i((.., 0))?; + self.head.forward(&xs_norm_clstoken) + } +} + +pub fn vit_small(vb: VarBuilder) -> Result { + DinoVisionTransformer::new(vb, 12, 384, 6) +} + +pub fn vit_base(vb: VarBuilder) -> Result { + DinoVisionTransformer::new(vb, 12, 768, 12) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 89ae0f8a..2908d345 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -8,6 +8,7 @@ pub mod convmixer; pub mod convnext; pub mod depth_anything_v2; pub mod dinov2; +pub mod dinov2reg4; pub mod distilbert; pub mod efficientnet; pub mod efficientvit;