use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

#[derive(Debug)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    internal_dim: usize,
    num_heads: usize,
}

impl Attention {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        downsample_rate: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        // Project to a smaller internal dimension when downsample_rate > 1.
        let internal_dim = embedding_dim / downsample_rate;
        let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
        let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
        let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
        let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            internal_dim,
            num_heads,
        })
    }

    // (b, n, c) -> (b, num_heads, n, c / num_heads)
    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n, c) = x.dims3()?;
        x.reshape((b, n, self.num_heads, c / self.num_heads))?
            .transpose(1, 2)
    }

    // (b, n_heads, n_tokens, c_per_head) -> (b, n_tokens, n_heads * c_per_head)
    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
        x.transpose(1, 2)?
            .reshape((b, n_tokens, n_heads * c_per_head))
    }

    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let q = self.q_proj.forward(q)?;
        let k = self.k_proj.forward(k)?;
        let v = self.v_proj.forward(v)?;

        let q = self.separate_heads(&q)?;
        let k = self.separate_heads(&k)?;
        let v = self.separate_heads(&v)?;

        // Scaled dot-product attention, softmax over the last dimension.
        let (_, _, _, c_per_head) = q.dims4()?;
        let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;

        let out = attn.matmul(&v)?;
        self.recombine_heads(&out)?.apply(&self.out_proj)
    }
}
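/// One block of the two-way transformer. It runs attention in both directions
/// between two sequences: sparse "query" tokens (the prompt embeddings) and
/// the dense flattened image embedding used as "keys". Each block applies
/// four steps, each followed by a layer norm: self-attention on the tokens,
/// token-to-image cross-attention, an MLP on the tokens, and image-to-token
/// cross-attention. Positional encodings are re-added to the attention inputs
/// at every step (except for the self-attention of the first layer).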
#[derive(Debug)]
struct TwoWayAttentionBlock {
    self_attn: Attention,
    norm1: LayerNorm,
    cross_attn_token_to_image: Attention,
    norm2: LayerNorm,
    mlp: crate::MlpBlock,
    norm3: LayerNorm,
    norm4: LayerNorm,
    cross_attn_image_to_token: Attention,
    skip_first_layer_pe: bool,
}

impl TwoWayAttentionBlock {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        skip_first_layer_pe: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
        let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
        let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
        let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
        let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
        // The cross-attention layers downsample the internal dimension by 2.
        let cross_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_token_to_image"),
        )?;
        let cross_attn_image_to_token = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_image_to_token"),
        )?;
        // TODO: use relu in this mlp
        let mlp = crate::MlpBlock::new(embedding_dim, mlp_dim, vb.pp("mlp"))?;
        Ok(Self {
            self_attn,
            norm1,
            cross_attn_token_to_image,
            norm2,
            mlp,
            norm3,
            norm4,
            cross_attn_image_to_token,
            skip_first_layer_pe,
        })
    }

    fn forward(
        &self,
        queries: &Tensor,
        keys: &Tensor,
        query_pe: &Tensor,
        key_pe: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        // Self attention block; the first layer skips adding the positional
        // encoding to the queries.
        let queries = if self.skip_first_layer_pe {
            self.self_attn.forward(queries, queries, queries)?
        } else {
            let q = (queries + query_pe)?;
            let attn_out = self.self_attn.forward(&q, &q, queries)?;
            (queries + attn_out)?
        };
        let queries = self.norm1.forward(&queries)?;

        // Cross attention block, tokens attending to image embedding
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
        let queries = (&queries + attn_out)?;
        let queries = self.norm2.forward(&queries)?;

        // MLP block
        let mlp_out = self.mlp.forward(&queries)?;
        let queries = (queries + mlp_out)?;
        let queries = self.norm3.forward(&queries)?;

        // Cross attention block, image embedding attending to tokens
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
        let keys = (keys + attn_out)?;
        let keys = self.norm4.forward(&keys)?;

        Ok((queries, keys))
    }
}

#[derive(Debug)]
pub struct TwoWayTransformer {
    layers: Vec<TwoWayAttentionBlock>,
    final_attn_token_to_image: Attention,
    norm_final_attn: LayerNorm,
}

impl TwoWayTransformer {
    pub fn new(
        depth: usize,
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let mut layers = Vec::with_capacity(depth);
        for i in 0..depth {
            // Only the first layer skips the positional encoding.
            let layer =
                TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
            layers.push(layer);
        }
        let final_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("final_attn_token_to_image"),
        )?;
        let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
        Ok(Self {
            layers,
            final_attn_token_to_image,
            norm_final_attn,
        })
    }

    pub fn forward(
        &self,
        image_embedding: &Tensor,
        image_pe: &Tensor,
        point_embedding: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let (_bs, _c, _h, _w) = image_embedding.dims4()?;
        // Flatten the spatial image embedding: (b, c, h, w) -> (b, h * w, c)
        let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
        let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;

        let mut queries = point_embedding.clone();
        let mut keys = image_embedding;

        for layer in self.layers.iter() {
            (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?;
        }

        // Final attention layer from the point tokens to the image.
        let q = (&queries + point_embedding)?;
        let k = (&keys + image_pe)?;
        let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
        let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;

        Ok((queries, keys))
    }
}
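
// A minimal smoke test sketching how the transformer is driven. The numbers
// below (embedding_dim 256, a 2-layer / 8-head stack, mlp_dim 2048, a 64x64
// image embedding and 5 prompt tokens) are illustrative assumptions, not
// values mandated by this file; `VarBuilder::zeros` stands in for real
// weights, so the test only exercises shapes.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::{DType, Device};

    #[test]
    fn two_way_transformer_shapes() -> Result<()> {
        let dev = Device::Cpu;
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let transformer = TwoWayTransformer::new(2, 256, 8, 2048, vb)?;
        // Dense image embedding and its positional encoding: (b, c, h, w).
        let image_embedding = Tensor::zeros((1, 256, 64, 64), DType::F32, &dev)?;
        let image_pe = Tensor::zeros((1, 256, 64, 64), DType::F32, &dev)?;
        // Sparse prompt tokens: (b, n_tokens, embedding_dim).
        let point_embedding = Tensor::zeros((1, 5, 256), DType::F32, &dev)?;
        let (queries, keys) =
            transformer.forward(&image_embedding, &image_pe, &point_embedding)?;
        // Queries keep the token shape; keys keep the flattened image shape.
        assert_eq!(queries.dims(), &[1, 5, 256]);
        assert_eq!(keys.dims(), &[1, 64 * 64, 256]);
        Ok(())
    }
}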