Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
More segment-anything again. (#764)
* More segment-anything again.
* Transformer block forward.
* Two-ways transformer.
* Position embeddings.
* Sketch the prompt encoder.
* More prompt-encoder.
* More prompt-encoder.
* Add the main sam module.
* Embed the transformer.
* And hook the transformer forward step.
* Build the model.
* Handle the global attn indexes.
* Get the model to load.
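
The last step, getting the model to load, amounts to pointing a VarBuilder at a converted checkpoint so that the vb.pp(...) lookups in the code below resolve to real tensors. A minimal sketch of that setup, assuming a safetensors file converted from the official SAM weights; the file name is a placeholder and not part of this commit:

// Sketch only: build a VarBuilder backed by a memory-mapped safetensors file.
// Inside the candle repo, candle-core is imported under the alias `candle`.
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

fn load_sam_weights() -> Result<VarBuilder<'static>> {
    let device = Device::Cpu;
    // `sam_vit_b.safetensors` is a hypothetical converted checkpoint name.
    // The call is unsafe because the file must stay unchanged while mapped.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["sam_vit_b.safetensors"], DType::F32, &device)?
    };
    Ok(vb)
}
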
@@ -75,3 +75,146 @@ struct TwoWayAttentionBlock {
    cross_attn_image_to_token: Attention,
    skip_first_layer_pe: bool,
}

impl TwoWayAttentionBlock {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        skip_first_layer_pe: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
        let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
        let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
        let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
        let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
        let cross_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_token_to_image"),
        )?;
        let cross_attn_image_to_token = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_image_to_token"),
        )?;
        // TODO: use relu in this mlp
        let mlp = crate::MlpBlock::new(embedding_dim, mlp_dim, vb.pp("mlp"))?;
        Ok(Self {
            self_attn,
            norm1,
            cross_attn_image_to_token,
            norm2,
            mlp,
            norm3,
            norm4,
            cross_attn_token_to_image,
            skip_first_layer_pe,
        })
    }

    fn forward(
        &self,
        queries: &Tensor,
        keys: &Tensor,
        query_pe: &Tensor,
        key_pe: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        // Self attention block
        let queries = if self.skip_first_layer_pe {
            self.self_attn.forward(queries, queries, queries)?
        } else {
            let q = (queries + query_pe)?;
            let attn_out = self.self_attn.forward(&q, &q, queries)?;
            (queries + attn_out)?
        };
        let queries = self.norm1.forward(&queries)?;

        // Cross attention block, tokens attending to image embedding
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
        let queries = (&queries + attn_out)?;
        let queries = self.norm2.forward(&queries)?;

        // MLP block
        let mlp_out = self.mlp.forward(&queries);
        let queries = (queries + mlp_out)?;
        let queries = self.norm3.forward(&queries)?;

        // Cross attention block, image embedding attending to tokens
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
        let keys = (keys + attn_out)?;
        let keys = self.norm4.forward(&keys)?;

        Ok((queries, keys))
    }
}
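
The third argument to Attention::new, 1 for the self-attention and 2 for both cross-attention directions, plays the role of the attention downsample rate in the reference SAM implementation: queries, keys and values are projected down to embedding_dim / downsample_rate before attention. A rough sketch of the resulting widths, assuming that reading of the argument; the helper below is illustrative and not part of this commit:

// Illustrative only: how a downsample rate shrinks the attention width,
// assuming Attention::new(embedding_dim, num_heads, downsample_rate, vb).
fn attention_widths(embedding_dim: usize, num_heads: usize, downsample_rate: usize) -> (usize, usize) {
    let internal_dim = embedding_dim / downsample_rate; // width of the q/k/v projections
    let head_dim = internal_dim / num_heads; // per-head width used in the attention
    (internal_dim, head_dim)
}

// With the SAM mask decoder defaults (embedding_dim = 256, num_heads = 8):
//   self-attention  (rate 1) -> internal_dim = 256, head_dim = 32
//   cross-attention (rate 2) -> internal_dim = 128, head_dim = 16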

#[derive(Debug)]
pub struct TwoWayTransformer {
    layers: Vec<TwoWayAttentionBlock>,
    final_attn_token_to_image: Attention,
    norm_final_attn: LayerNorm,
}

impl TwoWayTransformer {
    pub fn new(
        depth: usize,
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let mut layers = Vec::with_capacity(depth);
        for i in 0..depth {
            let layer =
                TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
            layers.push(layer)
        }
        let final_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("final_attn_token_to_image"),
        )?;
        let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
        Ok(Self {
            layers,
            final_attn_token_to_image,
            norm_final_attn,
        })
    }

    pub fn forward(
        &self,
        image_embedding: &Tensor,
        image_pe: &Tensor,
        point_embedding: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let (bs, c, h, w) = image_embedding.dims4()?;
        let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
        let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;

        let mut queries = point_embedding.clone();
        let mut keys = image_embedding;

        for layer in self.layers.iter() {
            (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
        }

        let q = (&queries + point_embedding)?;
        let k = (&keys + image_pe)?;
        let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
        let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;

        Ok((queries, keys))
    }
}
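
A minimal shape check for TwoWayTransformer::forward, assuming it is placed in this same module so the type is in scope, and using zero-initialized weights; the 64x64 embedding grid and the 5 prompt tokens are illustrative, while depth 2, embedding dim 256, 8 heads and MLP dim 2048 are the reference SAM mask decoder settings:

// Sketch only: drive the transformer with dummy tensors to check the shapes.
// VarBuilder::zeros produces all-zero weights, so only the shapes are meaningful.
fn two_way_transformer_shapes() -> candle::Result<()> {
    let device = candle::Device::Cpu;
    let vb = candle_nn::VarBuilder::zeros(candle::DType::F32, &device);
    let transformer = TwoWayTransformer::new(2, 256, 8, 2048, vb)?;

    let image_embedding = candle::Tensor::zeros((1, 256, 64, 64), candle::DType::F32, &device)?;
    let image_pe = candle::Tensor::zeros((1, 256, 64, 64), candle::DType::F32, &device)?;
    let point_embedding = candle::Tensor::zeros((1, 5, 256), candle::DType::F32, &device)?;

    let (queries, keys) = transformer.forward(&image_embedding, &image_pe, &point_embedding)?;
    assert_eq!(queries.dims(), &[1, 5, 256]); // one row per prompt token
    assert_eq!(keys.dims(), &[1, 64 * 64, 256]); // flattened image embedding
    Ok(())
}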