From ecff05d72b9bbbdf1d0d07f06ebf5b7b01bfe3db Mon Sep 17 00:00:00 2001
From: v-espitalier <125037408+v-espitalier@users.noreply.github.com>
Date: Thu, 4 Jul 2024 09:45:26 +0200
Subject: [PATCH] Beit: Add the gen_relative_position_index() function (#2306)

Co-authored-by: v-espitalier <>
---
 candle-examples/examples/beit/main.rs  |  2 +-
 candle-transformers/src/models/beit.rs | 89 ++++++++++++++++++--------
 2 files changed, 64 insertions(+), 27 deletions(-)

diff --git a/candle-examples/examples/beit/main.rs b/candle-examples/examples/beit/main.rs
index 5ef2a6ae..a256fd99 100644
--- a/candle-examples/examples/beit/main.rs
+++ b/candle-examples/examples/beit/main.rs
@@ -55,7 +55,7 @@ pub fn main() -> anyhow::Result<()> {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model("vincent-espitalier/candle-beit".into());
-            api.get("beit_base_patch16_384.in22k_ft_in22k_in1k_adapted.safetensors")?
+            api.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?
         }
         Some(model) => model.into(),
     };
diff --git a/candle-transformers/src/models/beit.rs b/candle-transformers/src/models/beit.rs
index c534032c..62bdd75a 100644
--- a/candle-transformers/src/models/beit.rs
+++ b/candle-transformers/src/models/beit.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
 
 const IMG_SIZE: usize = 384;
@@ -32,7 +32,6 @@ impl Attention {
         num_heads: usize,
         qkv_bias: bool,
         proj_bias: bool,
-        relative_position_index: &Tensor,
     ) -> Result<Self> {
         let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
         let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
@@ -42,7 +41,8 @@
             (num_relative_distance, num_heads),
             "relative_position_bias_table",
         )?;
-        let relative_position_index = relative_position_index.clone();
+        let relative_position_index =
+            Self::gen_relative_position_index(relative_position_bias_table.device())?;
         let scale = 1. / ((dim / num_heads) as f64).sqrt();
         Ok(Self {
             qkv,
@@ -56,6 +56,63 @@ impl Attention {
 }
 
 impl Attention {
+    // See: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/beit.py#L61
+    fn gen_relative_position_index(device: &Device) -> Result<Tensor> {
+        let num_relative_distance = (2 * WINDOW_SIZE - 1) * (2 * WINDOW_SIZE - 1) + 3;
+        let w_area = WINDOW_SIZE * WINDOW_SIZE;
+
+        let t_arange: Tensor = Tensor::arange(0, WINDOW_SIZE as u32, device)?;
+        let t_ndgrid = Tensor::meshgrid(&[&t_arange, &t_arange], false)?;
+        let coords_flatten = Tensor::stack(&t_ndgrid, 0)?.flatten(1, 2)?;
+
+        let tmp1 = coords_flatten
+            .unsqueeze(2)?
+            .broadcast_as((2, w_area, w_area))?
+            .to_dtype(DType::I64)?;
+        let tmp2 = coords_flatten
+            .unsqueeze(1)?
+            .broadcast_as((2, w_area, w_area))?
+            .to_dtype(DType::I64)?;
+        let relative_coords = (tmp1 - tmp2)?
+            .transpose(0, 1)? // 102
+            .transpose(1, 2)? // 120
+            .contiguous()?;
+
+        let relative_coords = relative_coords.slice_assign(
+            &[0..w_area, 0..w_area, 0..1],
+            &(relative_coords.i((0..w_area, 0..w_area, 0..1))? + (WINDOW_SIZE - 1) as f64)?,
+        )?;
+        let relative_coords = relative_coords.slice_assign(
+            &[0..w_area, 0..w_area, 1..2],
+            &(relative_coords.i((0..w_area, 0..w_area, 1..2))? + (WINDOW_SIZE - 1) as f64)?,
+        )?;
+        let relative_coords = relative_coords.slice_assign(
+            &[0..w_area, 0..w_area, 0..1],
+            &(relative_coords.i((.., .., 0..1))? * (2. * (WINDOW_SIZE as f64) - 1.))?,
+        )?;
+
+        Tensor::zeros((w_area + 1, w_area + 1), DType::I64, device)?
+            .slice_assign(&[1.., 1..], &relative_coords.sum(2)?)?
+            .slice_assign(
+                &[0..1, 0..(w_area + 1)],
+                &(Tensor::ones((1, w_area + 1), DType::I64, device)?
+                    * ((num_relative_distance - 3) as f64))?
+                .to_dtype(DType::I64)?,
+            )?
+            .slice_assign(
+                &[0..(w_area + 1), 0..1],
+                &(Tensor::ones((w_area + 1, 1), DType::I64, device)?
+                    * ((num_relative_distance - 2) as f64))?
+                .to_dtype(DType::I64)?,
+            )?
+            .slice_assign(
+                &[0..1, 0..1],
+                &(Tensor::ones((1, 1), DType::I64, device)?
+                    * ((num_relative_distance - 1) as f64))?
+                .to_dtype(DType::I64)?,
+            )
+    }
+
     fn _get_rel_pos_bias(&self) -> Result<Tensor> {
         self.relative_position_bias_table
             .index_select(
@@ -144,21 +201,9 @@ struct Block {
 }
 
 impl Block {
-    fn new(
-        vb: VarBuilder,
-        dim: usize,
-        num_heads: usize,
-        relative_position_index: &Tensor,
-    ) -> Result<Self> {
+    fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
         let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
-        let attn = Attention::new(
-            vb.pp("attn"),
-            dim,
-            num_heads,
-            true,
-            true,
-            relative_position_index,
-        )?;
+        let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
         let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
         let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
         let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
@@ -240,18 +285,10 @@ impl BeitVisionTransformer {
         let patch_embed = PatchEmbed::new(vb.pp("patch_embed"), PATCH_SIZE, 3, embed_dim)?;
         let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
         let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
-        let relative_position_index = vb.get((NB_TOKENS, NB_TOKENS), "relative_position_index")?;
         let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
        let blocks = (0..depth)
-            .map(|i| {
-                Block::new(
-                    vb_b.pp(&i.to_string()),
-                    embed_dim,
-                    num_heads,
-                    &relative_position_index,
-                )
-            })
+            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
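
Note for reviewers (not part of the patch): gen_relative_position_index follows the timm
reference linked in the code comment. Pairwise patch offsets are shifted to be non-negative,
the row offset is scaled by 2 * WINDOW_SIZE - 1 so each (row, col) pair flattens to a single
bucket id, and the last three buckets are reserved for the cls token. The sketch below
replays that bucketing on a toy 2x2 window in plain Rust (no candle); the names W, coords
and index are illustrative only:

    // Sketch: relative-position index construction for a toy window,
    // mirroring the steps of gen_relative_position_index above.
    fn main() {
        const W: i64 = 2; // toy window side; the model uses WINDOW_SIZE = IMG_SIZE / PATCH_SIZE = 24
        let area = (W * W) as usize;
        let num_relative_distance = (2 * W - 1) * (2 * W - 1) + 3;

        // (row, col) coordinates of each flattened patch position.
        let coords: Vec<(i64, i64)> = (0..W)
            .flat_map(|r| (0..W).map(move |c| (r, c)))
            .collect();

        // index[i + 1][j + 1] = bucket id of the offset from patch j to patch i,
        // shifted to be non-negative: (dr + W - 1) * (2W - 1) + (dc + W - 1).
        let mut index = vec![vec![0i64; area + 1]; area + 1];
        for i in 0..area {
            for j in 0..area {
                let dr = coords[i].0 - coords[j].0 + (W - 1);
                let dc = coords[i].1 - coords[j].1 + (W - 1);
                index[i + 1][j + 1] = dr * (2 * W - 1) + dc;
            }
        }
        // The three extra buckets handle the cls token, in the same order
        // as the slice_assign calls above.
        for k in 0..=area {
            index[0][k] = num_relative_distance - 3; // cls -> patch
            index[k][0] = num_relative_distance - 2; // patch -> cls
        }
        index[0][0] = num_relative_distance - 1; // cls -> cls

        for row in &index {
            println!("{row:?}");
        }
    }

For W = 2 this prints a 5x5 table: the inner 4x4 block holds bucket ids 0..=8, row 0 is
filled with 9, column 0 with 10, and the corner cell is 11, the same layout the patched
function produces (at WINDOW_SIZE = 24 the table is 577x577 and num_relative_distance
is 2212).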