mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
More segment-anything again. (#764)
* More segment-anything again. * Transformer block forward. * Two-ways transformer. * Position embeddings. * Sketch the prompt encoder. * More prompt-encoder. * More prompt-encoder. * Add the main sam module. * Embed the transformer. * And hook the transformer forward step. * Build the model. * Handle the global attn indexes. * Get the model to load.
This commit is contained in:
@ -47,7 +47,7 @@ impl Attention {
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
use_rel_pos: bool,
|
||||
window_size: usize,
|
||||
input_size: (usize, usize),
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
|
||||
@ -55,8 +55,8 @@ impl Attention {
|
||||
let head_dim = dim / num_heads;
|
||||
let scale = 1. / (head_dim as f64).sqrt();
|
||||
let rel_pos_hw = if use_rel_pos {
|
||||
let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
|
||||
let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
|
||||
let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
|
||||
let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
|
||||
Some((h, w))
|
||||
} else {
|
||||
None
|
||||
@ -114,16 +114,22 @@ impl Block {
|
||||
qkv_bias: bool,
|
||||
use_rel_pos: bool,
|
||||
window_size: usize,
|
||||
input_size: (usize, usize),
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
|
||||
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
|
||||
let input_size_attn = if window_size == 0 {
|
||||
input_size
|
||||
} else {
|
||||
(window_size, window_size)
|
||||
};
|
||||
let attn = Attention::new(
|
||||
dim,
|
||||
num_heads,
|
||||
qkv_bias,
|
||||
use_rel_pos,
|
||||
window_size,
|
||||
input_size_attn,
|
||||
vb.pp("attn"),
|
||||
)?;
|
||||
let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
|
||||
@ -154,7 +160,7 @@ impl Module for Block {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ImageEncoderViT {
|
||||
pub struct ImageEncoderViT {
|
||||
img_size: usize,
|
||||
patch_embed: PatchEmbed,
|
||||
blocks: Vec<Block>,
|
||||
@ -167,7 +173,7 @@ struct ImageEncoderViT {
|
||||
|
||||
impl ImageEncoderViT {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
pub fn new(
|
||||
img_size: usize,
|
||||
patch_size: usize,
|
||||
in_chans: usize,
|
||||
@ -179,6 +185,7 @@ impl ImageEncoderViT {
|
||||
use_rel_pos: bool,
|
||||
use_abs_pos: bool,
|
||||
window_size: usize,
|
||||
global_attn_indexes: &[usize],
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let patch_embed = PatchEmbed::new(
|
||||
@ -192,12 +199,18 @@ impl ImageEncoderViT {
|
||||
let mut blocks = Vec::with_capacity(depth);
|
||||
let vb_b = vb.pp("blocks");
|
||||
for i in 0..depth {
|
||||
let window_size = if global_attn_indexes.contains(&i) {
|
||||
0
|
||||
} else {
|
||||
window_size
|
||||
};
|
||||
let block = Block::new(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
qkv_bias,
|
||||
use_rel_pos,
|
||||
window_size,
|
||||
(img_size / patch_size, img_size / patch_size),
|
||||
vb_b.pp(i),
|
||||
)?;
|
||||
blocks.push(block)
|
||||
|
Reference in New Issue
Block a user