mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Beit: Add the gen_relative_position_index() function (#2306)
Co-authored-by: v-espitalier <>
candle-examples/examples/beit/main.rs:
@@ -55,7 +55,7 @@ pub fn main() -> anyhow::Result<()> {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model("vincent-espitalier/candle-beit".into());
-            api.get("beit_base_patch16_384.in22k_ft_in22k_in1k_adapted.safetensors")?
+            api.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?
         }
         Some(model) => model.into(),
     };
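The example now downloads the stock timm checkpoint instead of the "_adapted" file, since the relative_position_index tensor that file bundled is now generated at construction time (see the hunks below). A minimal standalone sketch of the same fetch, using the hf-hub sync API the example already relies on; `fetch_weights` is a hypothetical helper name:

fn fetch_weights() -> anyhow::Result<std::path::PathBuf> {
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("vincent-espitalier/candle-beit".into());
    Ok(repo.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?)
}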
candle-transformers/src/models/beit.rs:
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
 
 const IMG_SIZE: usize = 384;
@@ -32,7 +32,6 @@ impl Attention {
         num_heads: usize,
         qkv_bias: bool,
         proj_bias: bool,
-        relative_position_index: &Tensor,
     ) -> Result<Self> {
         let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
         let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
@@ -42,7 +41,8 @@ impl Attention {
             (num_relative_distance, num_heads),
             "relative_position_bias_table",
         )?;
-        let relative_position_index = relative_position_index.clone();
+        let relative_position_index =
+            Self::gen_relative_position_index(relative_position_bias_table.device())?;
         let scale = 1. / ((dim / num_heads) as f64).sqrt();
         Ok(Self {
             qkv,
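Attention::new no longer takes the index as a parameter; it generates it on the same device as the bias table. For this model the sizes work out as follows (a hypothetical sanity sketch, assuming WINDOW_SIZE = IMG_SIZE / PATCH_SIZE = 384 / 16 as in this file):

fn main() {
    let window_size = 384 / 16; // WINDOW_SIZE: 24 patches per side
    let w_area = window_size * window_size; // 576 patch tokens
    // (2 * 24 - 1)^2 distinct planar offsets, plus 3 cls-token entries.
    let num_relative_distance = (2 * window_size - 1) * (2 * window_size - 1) + 3;
    assert_eq!(num_relative_distance, 2212);
    // The index table also covers the cls token: (577, 577) entries.
    assert_eq!(w_area + 1, 577);
}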
@@ -56,6 +56,63 @@ impl Attention {
     }
 }
 
 impl Attention {
+    // See: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/beit.py#L61
+    fn gen_relative_position_index(device: &Device) -> Result<Tensor> {
+        let num_relative_distance = (2 * WINDOW_SIZE - 1) * (2 * WINDOW_SIZE - 1) + 3;
+        let w_area = WINDOW_SIZE * WINDOW_SIZE;
+
+        let t_arange: Tensor = Tensor::arange(0, WINDOW_SIZE as u32, device)?;
+        let t_ndgrid = Tensor::meshgrid(&[&t_arange, &t_arange], false)?;
+        let coords_flatten = Tensor::stack(&t_ndgrid, 0)?.flatten(1, 2)?;
+
+        let tmp1 = coords_flatten
+            .unsqueeze(2)?
+            .broadcast_as((2, w_area, w_area))?
+            .to_dtype(DType::I64)?;
+        let tmp2 = coords_flatten
+            .unsqueeze(1)?
+            .broadcast_as((2, w_area, w_area))?
+            .to_dtype(DType::I64)?;
+        let relative_coords = (tmp1 - tmp2)?
+            .transpose(0, 1)? // 102
+            .transpose(1, 2)? // 120
+            .contiguous()?;
+
+        let relative_coords = relative_coords.slice_assign(
+            &[0..w_area, 0..w_area, 0..1],
+            &(relative_coords.i((0..w_area, 0..w_area, 0..1))? + (WINDOW_SIZE - 1) as f64)?,
+        )?;
+        let relative_coords = relative_coords.slice_assign(
+            &[0..w_area, 0..w_area, 1..2],
+            &(relative_coords.i((0..w_area, 0..w_area, 1..2))? + (WINDOW_SIZE - 1) as f64)?,
+        )?;
+        let relative_coords = relative_coords.slice_assign(
+            &[0..w_area, 0..w_area, 0..1],
+            &(relative_coords.i((.., .., 0..1))? * (2. * (WINDOW_SIZE as f64) - 1.))?,
+        )?;
+
+        Tensor::zeros((w_area + 1, w_area + 1), DType::I64, device)?
+            .slice_assign(&[1.., 1..], &relative_coords.sum(2)?)?
+            .slice_assign(
+                &[0..1, 0..(w_area + 1)],
+                &(Tensor::ones((1, w_area + 1), DType::I64, device)?
+                    * ((num_relative_distance - 3) as f64))?
+                .to_dtype(DType::I64)?,
+            )?
+            .slice_assign(
+                &[0..(w_area + 1), 0..1],
+                &(Tensor::ones((w_area + 1, 1), DType::I64, device)?
+                    * ((num_relative_distance - 2) as f64))?
+                .to_dtype(DType::I64)?,
+            )?
+            .slice_assign(
+                &[0..1, 0..1],
+                &(Tensor::ones((1, 1), DType::I64, device)?
+                    * ((num_relative_distance - 1) as f64))?
+                .to_dtype(DType::I64)?,
+            )
+    }
+
     fn _get_rel_pos_bias(&self) -> Result<Tensor> {
         self.relative_position_bias_table
             .index_select(
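In effect, each pair of patch tokens is mapped to the bias-table row (dy + W - 1) * (2 * W - 1) + (dx + W - 1), where (dy, dx) is their planar offset and W is WINDOW_SIZE, with the last three rows reserved for cls-to-patch, patch-to-cls and cls-to-cls. A plain-Rust reference sketch of the same computation (hypothetical, not part of the commit; handy for cross-checking the tensor version on a small window):

fn relative_position_index(window: usize) -> Vec<Vec<i64>> {
    let num_relative_distance = ((2 * window - 1) * (2 * window - 1) + 3) as i64;
    let w_area = window * window;
    // Flattened (y, x) coordinates of each patch in the window.
    let coords: Vec<(i64, i64)> = (0..window as i64)
        .flat_map(|y| (0..window as i64).map(move |x| (y, x)))
        .collect();
    let mut index = vec![vec![0i64; w_area + 1]; w_area + 1];
    for i in 0..w_area {
        for j in 0..w_area {
            // Shift both offsets into [0, 2 * window - 2], then combine.
            let dy = coords[i].0 - coords[j].0 + (window as i64 - 1);
            let dx = coords[i].1 - coords[j].1 + (window as i64 - 1);
            index[i + 1][j + 1] = dy * (2 * window as i64 - 1) + dx;
        }
    }
    // Reserved bias entries for the cls token, matching the tensor version.
    for k in 0..=w_area {
        index[0][k] = num_relative_distance - 3; // cls -> patch row
        index[k][0] = num_relative_distance - 2; // patch -> cls column
    }
    index[0][0] = num_relative_distance - 1; // cls -> cls
    index
}

For window = 2 this yields a 5x5 table: inner entries in 0..=8, first row all 9, first column all 10, and the top-left corner 11 (num_relative_distance = 12).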
@@ -144,21 +201,9 @@ struct Block {
 }
 
 impl Block {
-    fn new(
-        vb: VarBuilder,
-        dim: usize,
-        num_heads: usize,
-        relative_position_index: &Tensor,
-    ) -> Result<Self> {
+    fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
         let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
-        let attn = Attention::new(
-            vb.pp("attn"),
-            dim,
-            num_heads,
-            true,
-            true,
-            relative_position_index,
-        )?;
+        let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
         let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
         let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
         let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
@@ -240,18 +285,10 @@ impl BeitVisionTransformer {
         let patch_embed = PatchEmbed::new(vb.pp("patch_embed"), PATCH_SIZE, 3, embed_dim)?;
         let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
         let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
-        let relative_position_index = vb.get((NB_TOKENS, NB_TOKENS), "relative_position_index")?;
         let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| {
-                Block::new(
-                    vb_b.pp(&i.to_string()),
-                    embed_dim,
-                    num_heads,
-                    &relative_position_index,
-                )
-            })
+            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
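With the index generated in code, the checkpoint needs no extra tensors and loads through the usual candle pattern. A hypothetical loading sketch (the dtype and the helper name are assumptions, not part of the commit):

use candle::{DType, Device};
use candle_nn::VarBuilder;

fn load_vb(path: &std::path::Path, device: &Device) -> candle::Result<VarBuilder<'static>> {
    // Safety: mmaps the file; it must not be modified while mapped.
    unsafe { VarBuilder::from_mmaped_safetensors(&[path], DType::F32, device) }
}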