Mirror of https://github.com/huggingface/candle.git
Beit: Add the gen_relative_position_index() function (#2306)
Co-authored-by: v-espitalier <>
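Background for the diff below: the relative-position bias is addressed through an (NB_TOKENS x NB_TOKENS) table of bucket indices that depends only on WINDOW_SIZE, so the new gen_relative_position_index() rebuilds it at model-construction time instead of reading a precomputed "relative_position_index" tensor from the adapted checkpoint. That is why the parameter drops out of Attention::new and Block::new, and why the example switches to the stock beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors weights. Below is a minimal plain-Rust sketch of the same index layout; the loop formulation and the toy WINDOW_SIZE are illustration-only assumptions, not part of the commit.

// Minimal plain-Rust sketch of the mapping encoded by gen_relative_position_index,
// written with loops instead of candle tensor ops so it runs standalone.
// WINDOW_SIZE = 2 is a toy value for illustration only; the model uses
// IMG_SIZE / PATCH_SIZE (384 / 16 = 24).
const WINDOW_SIZE: usize = 2;

fn relative_position_index() -> Vec<Vec<i64>> {
    let num_relative_distance = (2 * WINDOW_SIZE - 1) * (2 * WINDOW_SIZE - 1) + 3;
    let w_area = WINDOW_SIZE * WINDOW_SIZE;
    let mut index = vec![vec![0i64; w_area + 1]; w_area + 1];
    // Patch-to-patch pairs: shift the relative (dy, dx) offset so it is non-negative,
    // then flatten it into a single bucket id in 0..(2 * WINDOW_SIZE - 1)^2.
    for i in 0..w_area {
        for j in 0..w_area {
            let dy = (i / WINDOW_SIZE) as i64 - (j / WINDOW_SIZE) as i64 + WINDOW_SIZE as i64 - 1;
            let dx = (i % WINDOW_SIZE) as i64 - (j % WINDOW_SIZE) as i64 + WINDOW_SIZE as i64 - 1;
            index[i + 1][j + 1] = dy * (2 * WINDOW_SIZE as i64 - 1) + dx;
        }
    }
    // The last three buckets of the bias table are reserved for the cls token,
    // assigned in the same order as the tensor version: cls->patch row, then
    // patch->cls column, then the cls->cls corner.
    for j in 0..=w_area {
        index[0][j] = (num_relative_distance - 3) as i64;
    }
    for row in index.iter_mut() {
        row[0] = (num_relative_distance - 2) as i64;
    }
    index[0][0] = (num_relative_distance - 1) as i64;
    index
}

fn main() {
    // Expected output for WINDOW_SIZE = 2 (a 5x5 table: 4 patch tokens + 1 cls token):
    // [11, 9, 9, 9, 9]
    // [10, 4, 3, 1, 0]
    // [10, 5, 4, 2, 1]
    // [10, 7, 6, 4, 3]
    // [10, 8, 7, 5, 4]
    for row in relative_position_index() {
        println!("{row:?}");
    }
}

The actual implementation in the diff builds the same table with candle tensor ops, as an I64 tensor on the bias table's device.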
@@ -55,7 +55,7 @@ pub fn main() -> anyhow::Result<()> {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model("vincent-espitalier/candle-beit".into());
-            api.get("beit_base_patch16_384.in22k_ft_in22k_in1k_adapted.safetensors")?
+            api.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?
         }
         Some(model) => model.into(),
     };
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
 
 const IMG_SIZE: usize = 384;
@@ -32,7 +32,6 @@ impl Attention {
         num_heads: usize,
         qkv_bias: bool,
         proj_bias: bool,
-        relative_position_index: &Tensor,
     ) -> Result<Self> {
         let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
         let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
@@ -42,7 +41,8 @@ impl Attention {
             (num_relative_distance, num_heads),
             "relative_position_bias_table",
         )?;
-        let relative_position_index = relative_position_index.clone();
+        let relative_position_index =
+            Self::gen_relative_position_index(relative_position_bias_table.device())?;
         let scale = 1. / ((dim / num_heads) as f64).sqrt();
         Ok(Self {
             qkv,
@@ -56,6 +56,63 @@ impl Attention {
     }
 }
 
 impl Attention {
+    // See: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/beit.py#L61
+    fn gen_relative_position_index(device: &Device) -> Result<Tensor> {
+        let num_relative_distance = (2 * WINDOW_SIZE - 1) * (2 * WINDOW_SIZE - 1) + 3;
+        let w_area = WINDOW_SIZE * WINDOW_SIZE;
+
+        let t_arange: Tensor = Tensor::arange(0, WINDOW_SIZE as u32, device)?;
+        let t_ndgrid = Tensor::meshgrid(&[&t_arange, &t_arange], false)?;
+        let coords_flatten = Tensor::stack(&t_ndgrid, 0)?.flatten(1, 2)?;
+
+        let tmp1 = coords_flatten
+            .unsqueeze(2)?
+            .broadcast_as((2, w_area, w_area))?
+            .to_dtype(DType::I64)?;
+        let tmp2 = coords_flatten
+            .unsqueeze(1)?
+            .broadcast_as((2, w_area, w_area))?
+            .to_dtype(DType::I64)?;
+        let relative_coords = (tmp1 - tmp2)?
+            .transpose(0, 1)? // 102
+            .transpose(1, 2)? // 120
+            .contiguous()?;
+
+        let relative_coords = relative_coords.slice_assign(
+            &[0..w_area, 0..w_area, 0..1],
+            &(relative_coords.i((0..w_area, 0..w_area, 0..1))? + (WINDOW_SIZE - 1) as f64)?,
+        )?;
+        let relative_coords = relative_coords.slice_assign(
+            &[0..w_area, 0..w_area, 1..2],
+            &(relative_coords.i((0..w_area, 0..w_area, 1..2))? + (WINDOW_SIZE - 1) as f64)?,
+        )?;
+        let relative_coords = relative_coords.slice_assign(
+            &[0..w_area, 0..w_area, 0..1],
+            &(relative_coords.i((.., .., 0..1))? * (2. * (WINDOW_SIZE as f64) - 1.))?,
+        )?;
+
+        Tensor::zeros((w_area + 1, w_area + 1), DType::I64, device)?
+            .slice_assign(&[1.., 1..], &relative_coords.sum(2)?)?
+            .slice_assign(
+                &[0..1, 0..(w_area + 1)],
+                &(Tensor::ones((1, w_area + 1), DType::I64, device)?
+                    * ((num_relative_distance - 3) as f64))?
+                .to_dtype(DType::I64)?,
+            )?
+            .slice_assign(
+                &[0..(w_area + 1), 0..1],
+                &(Tensor::ones((w_area + 1, 1), DType::I64, device)?
+                    * ((num_relative_distance - 2) as f64))?
+                .to_dtype(DType::I64)?,
+            )?
+            .slice_assign(
+                &[0..1, 0..1],
+                &(Tensor::ones((1, 1), DType::I64, device)?
+                    * ((num_relative_distance - 1) as f64))?
+                .to_dtype(DType::I64)?,
+            )
+    }
+
     fn _get_rel_pos_bias(&self) -> Result<Tensor> {
         self.relative_position_bias_table
             .index_select(
@@ -144,21 +201,9 @@ struct Block {
 }
 
 impl Block {
-    fn new(
-        vb: VarBuilder,
-        dim: usize,
-        num_heads: usize,
-        relative_position_index: &Tensor,
-    ) -> Result<Self> {
+    fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
         let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
-        let attn = Attention::new(
-            vb.pp("attn"),
-            dim,
-            num_heads,
-            true,
-            true,
-            relative_position_index,
-        )?;
+        let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
         let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
         let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
         let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
@@ -240,18 +285,10 @@ impl BeitVisionTransformer {
         let patch_embed = PatchEmbed::new(vb.pp("patch_embed"), PATCH_SIZE, 3, embed_dim)?;
         let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
         let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
-        let relative_position_index = vb.get((NB_TOKENS, NB_TOKENS), "relative_position_index")?;
         let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| {
-                Block::new(
-                    vb_b.pp(&i.to_string()),
-                    embed_dim,
-                    num_heads,
-                    &relative_position_index,
-                )
-            })
+            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,