Beit: Add the gen_relative_position_index() function (#2306)

Co-authored-by: v-espitalier <>
Author: v-espitalier
Date: 2024-07-04 09:45:26 +02:00
Committed by: GitHub
parent 7f1ba8038c
commit ecff05d72b
2 changed files with 64 additions and 27 deletions


@@ -55,7 +55,7 @@ pub fn main() -> anyhow::Result<()> {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("vincent-espitalier/candle-beit".into());
api.get("beit_base_patch16_384.in22k_ft_in22k_in1k_adapted.safetensors")?
api.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?
}
Some(model) => model.into(),
};
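Because the relative-position index is now generated in code (see the model changes below), the example can fetch the stock beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors checkpoint instead of a hand-adapted file with a precomputed relative_position_index tensor baked in. As a reference point, here is a minimal sketch of loading the fetched file into a VarBuilder; the dtype and device picked below are assumptions for illustration, not part of this diff:

// Sketch only: map the checkpoint returned by api.get() into a VarBuilder.
use candle::{DType, Device};
use candle_nn::VarBuilder;

fn load_weights(model_file: &std::path::Path) -> candle::Result<VarBuilder<'static>> {
    let device = Device::Cpu; // assumption: CPU here; the example may pick a GPU when available
    unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device) }
}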


@@ -1,4 +1,4 @@
- use candle::{DType, IndexOp, Result, Tensor, D};
+ use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
const IMG_SIZE: usize = 384;
@@ -32,7 +32,6 @@ impl Attention {
num_heads: usize,
qkv_bias: bool,
proj_bias: bool,
- relative_position_index: &Tensor,
) -> Result<Self> {
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
@@ -42,7 +41,8 @@ impl Attention {
(num_relative_distance, num_heads),
"relative_position_bias_table",
)?;
- let relative_position_index = relative_position_index.clone();
+ let relative_position_index =
+ Self::gen_relative_position_index(relative_position_bias_table.device())?;
let scale = 1. / ((dim / num_heads) as f64).sqrt();
Ok(Self {
qkv,
@@ -56,6 +56,63 @@ impl Attention {
}
impl Attention {
+ // See: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/beit.py#L61
+ fn gen_relative_position_index(device: &Device) -> Result<Tensor> {
+ let num_relative_distance = (2 * WINDOW_SIZE - 1) * (2 * WINDOW_SIZE - 1) + 3;
+ let w_area = WINDOW_SIZE * WINDOW_SIZE;
+ let t_arange: Tensor = Tensor::arange(0, WINDOW_SIZE as u32, device)?;
+ let t_ndgrid = Tensor::meshgrid(&[&t_arange, &t_arange], false)?;
+ let coords_flatten = Tensor::stack(&t_ndgrid, 0)?.flatten(1, 2)?;
+ let tmp1 = coords_flatten
+ .unsqueeze(2)?
+ .broadcast_as((2, w_area, w_area))?
+ .to_dtype(DType::I64)?;
+ let tmp2 = coords_flatten
+ .unsqueeze(1)?
+ .broadcast_as((2, w_area, w_area))?
+ .to_dtype(DType::I64)?;
+ let relative_coords = (tmp1 - tmp2)?
+ .transpose(0, 1)? // 102
+ .transpose(1, 2)? // 120
+ .contiguous()?;
+ let relative_coords = relative_coords.slice_assign(
+ &[0..w_area, 0..w_area, 0..1],
+ &(relative_coords.i((0..w_area, 0..w_area, 0..1))? + (WINDOW_SIZE - 1) as f64)?,
+ )?;
+ let relative_coords = relative_coords.slice_assign(
+ &[0..w_area, 0..w_area, 1..2],
+ &(relative_coords.i((0..w_area, 0..w_area, 1..2))? + (WINDOW_SIZE - 1) as f64)?,
+ )?;
+ let relative_coords = relative_coords.slice_assign(
+ &[0..w_area, 0..w_area, 0..1],
+ &(relative_coords.i((.., .., 0..1))? * (2. * (WINDOW_SIZE as f64) - 1.))?,
+ )?;
+ Tensor::zeros((w_area + 1, w_area + 1), DType::I64, device)?
+ .slice_assign(&[1.., 1..], &relative_coords.sum(2)?)?
+ .slice_assign(
+ &[0..1, 0..(w_area + 1)],
+ &(Tensor::ones((1, w_area + 1), DType::I64, device)?
+ * ((num_relative_distance - 3) as f64))?
+ .to_dtype(DType::I64)?,
+ )?
+ .slice_assign(
+ &[0..(w_area + 1), 0..1],
+ &(Tensor::ones((w_area + 1, 1), DType::I64, device)?
+ * ((num_relative_distance - 2) as f64))?
+ .to_dtype(DType::I64)?,
+ )?
+ .slice_assign(
+ &[0..1, 0..1],
+ &(Tensor::ones((1, 1), DType::I64, device)?
+ * ((num_relative_distance - 1) as f64))?
+ .to_dtype(DType::I64)?,
+ )
+ }
fn _get_rel_pos_bias(&self) -> Result<Tensor> {
self.relative_position_bias_table
.index_select(
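The new helper is a port of timm's gen_relative_position_index: every pair of patches in the WINDOW_SIZE x WINDOW_SIZE grid gets a bucket id derived from its 2-D relative offset, and the last three entries of the table (num_relative_distance - 3, - 2, - 1) are reserved for the cls-to-patch, patch-to-cls and cls-to-cls positions. The broadcast/transpose/slice_assign chain above is equivalent to the plain-loop sketch below, which is illustrative only and not part of the commit (the function name and the use of Vec instead of Tensor are choices made here):

// Illustrative re-derivation of the same (area + 1) x (area + 1) table with plain loops.
fn relative_position_index_loops(w: i64) -> Vec<Vec<i64>> {
    let area = (w * w) as usize;
    let nrd = (2 * w - 1) * (2 * w - 1) + 3; // num_relative_distance
    let mut idx = vec![vec![0i64; area + 1]; area + 1];
    for i in 0..area {
        let (yi, xi) = (i as i64 / w, i as i64 % w);
        for j in 0..area {
            let (yj, xj) = (j as i64 / w, j as i64 % w);
            // shift both offsets into [0, 2w - 2], then combine them into one bucket id
            let dy = yi - yj + (w - 1);
            let dx = xi - xj + (w - 1);
            idx[i + 1][j + 1] = dy * (2 * w - 1) + dx;
        }
    }
    for k in 0..=area {
        idx[0][k] = nrd - 3; // cls token attending to patches
        idx[k][0] = nrd - 2; // patches attending to the cls token
    }
    idx[0][0] = nrd - 1; // cls token attending to itself
    idx
}

For w = 2 this yields a 5 x 5 table with 9 across the first row, 10 down the first column, 11 in the corner, and 4 on the diagonal of the patch-patch block (a zero offset maps to the centre bucket of the (2w - 1)^2 grid), matching what the tensor version produces for a 2 x 2 window.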
@@ -144,21 +201,9 @@ struct Block {
}
impl Block {
- fn new(
- vb: VarBuilder,
- dim: usize,
- num_heads: usize,
- relative_position_index: &Tensor,
- ) -> Result<Self> {
+ fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
- let attn = Attention::new(
- vb.pp("attn"),
- dim,
- num_heads,
- true,
- true,
- relative_position_index,
- )?;
+ let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
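With the index generated inside Attention::new, Block::new no longer threads a relative_position_index argument through, and each attention layer only touches its own copy when building the per-head bias. A simplified sketch of that lookup, condensed from _get_rel_pos_bias above (the free-function form, the shapes and the omission of any final unsqueeze are assumptions made here):

// Sketch only: expand the (nb_tokens, nb_tokens) index into a per-head attention bias.
use candle::{Result, Tensor};

fn rel_pos_bias(bias_table: &Tensor, index: &Tensor, nb_tokens: usize, num_heads: usize) -> Result<Tensor> {
    bias_table
        .index_select(&index.flatten_all()?, 0)?     // (nb_tokens * nb_tokens, num_heads)
        .reshape((nb_tokens, nb_tokens, num_heads))? // back to a square grid of rows
        .permute((2, 0, 1))                          // (num_heads, nb_tokens, nb_tokens)
}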
@@ -240,18 +285,10 @@ impl BeitVisionTransformer {
let patch_embed = PatchEmbed::new(vb.pp("patch_embed"), PATCH_SIZE, 3, embed_dim)?;
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
- let relative_position_index = vb.get((NB_TOKENS, NB_TOKENS), "relative_position_index")?;
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
- .map(|i| {
- Block::new(
- vb_b.pp(&i.to_string()),
- embed_dim,
- num_heads,
- &relative_position_index,
- )
- })
+ .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,