More segment-anything again. (#764)

* More segment-anything again.

* Transformer block forward.

* Two-way transformer.

* Position embeddings.

* Sketch the prompt encoder.

* More prompt-encoder.

* More prompt-encoder.

* Add the main sam module.

* Embed the transformer.

* And hook the transformer forward step.

* Build the model.

* Handle the global attn indexes.

* Get the model to load.
This commit is contained in:
Laurent Mazare
2023-09-07 13:06:55 +02:00
committed by GitHub
parent 8c991df394
commit 7b50f3e106
6 changed files with 454 additions and 20 deletions

View File

@ -47,7 +47,7 @@ impl Attention {
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
@ -55,8 +55,8 @@ impl Attention {
let head_dim = dim / num_heads;
let scale = 1. / (head_dim as f64).sqrt();
let rel_pos_hw = if use_rel_pos {
let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
Some((h, w))
} else {
None
@ -114,16 +114,22 @@ impl Block {
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let input_size_attn = if window_size == 0 {
input_size
} else {
(window_size, window_size)
};
let attn = Attention::new(
dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
input_size_attn,
vb.pp("attn"),
)?;
let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
@ -154,7 +160,7 @@ impl Module for Block {
}
#[derive(Debug)]
struct ImageEncoderViT {
pub struct ImageEncoderViT {
img_size: usize,
patch_embed: PatchEmbed,
blocks: Vec<Block>,
@ -167,7 +173,7 @@ struct ImageEncoderViT {
impl ImageEncoderViT {
#[allow(clippy::too_many_arguments)]
fn new(
pub fn new(
img_size: usize,
patch_size: usize,
in_chans: usize,
@ -179,6 +185,7 @@ impl ImageEncoderViT {
use_rel_pos: bool,
use_abs_pos: bool,
window_size: usize,
global_attn_indexes: &[usize],
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(
@ -192,12 +199,18 @@ impl ImageEncoderViT {
let mut blocks = Vec::with_capacity(depth);
let vb_b = vb.pp("blocks");
for i in 0..depth {
let window_size = if global_attn_indexes.contains(&i) {
0
} else {
window_size
};
let block = Block::new(
embed_dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
(img_size / patch_size, img_size / patch_size),
vb_b.pp(i),
)?;
blocks.push(block)