diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index a53cff8b..de16f70c 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -8,16 +8,15 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use clap::Parser;
+mod model_image_encoder;
+mod model_mask_decoder;
+mod model_transformer;
 
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use clap::Parser;
 
-const IMG_SIZE: usize = 518;
-const PATCH_SIZE: usize = 14;
-const NUM_CLASSES: usize = 1000;
-
-fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
     if bias {
         candle_nn::linear(in_dim, out_dim, vb)
     } else {
@@ -26,13 +25,13 @@ fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<
 }
 
 #[derive(Debug)]
-struct MlpBlock {
+pub struct MlpBlock {
     lin1: Linear,
     lin2: Linear,
 }
 
 impl MlpBlock {
-    fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
+    pub fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
         let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
         let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
         Ok(Self { lin1, lin2 })
@@ -45,347 +44,6 @@ impl Module for MlpBlock {
     }
 }
 
-#[derive(Debug)]
-struct PatchEmbed {
-    proj: candle_nn::Conv2d,
-}
-
-impl PatchEmbed {
-    fn new(
-        in_chans: usize,
-        embed_dim: usize,
-        k_size: usize,
-        stride: usize,
-        padding: usize,
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let cfg = candle_nn::Conv2dConfig {
-            stride,
-            padding,
-            ..Default::default()
-        };
-        let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
-        Ok(Self { proj })
-    }
-}
-
-impl Module for PatchEmbed {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply(&self.proj)?.permute((0, 2, 3, 1))
-    }
-}
-
-#[derive(Debug)]
-struct Attention {
-    qkv: Linear,
-    proj: Linear,
-    num_heads: usize,
-    scale: f64,
-    use_rel_pos: bool,
-    rel_pos_hw: Option<(Tensor, Tensor)>,
-}
-
-impl Attention {
-    fn new(
-        dim: usize,
-        num_heads: usize,
-        qkv_bias: bool,
-        use_rel_pos: bool,
-        window_size: usize,
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
-        let proj = linear(vb.pp("proj"), dim, dim, true)?;
-        let head_dim = dim / num_heads;
-        let scale = 1. / (head_dim as f64).sqrt();
-        let rel_pos_hw = if use_rel_pos {
-            let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
-            let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
-            Some((h, w))
-        } else {
-            None
-        };
-        Ok(Self {
-            qkv,
-            proj,
-            num_heads,
-            scale,
-            use_rel_pos,
-            rel_pos_hw,
-        })
-    }
-}
-
-impl Module for Attention {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let (b, h, w, c) = xs.dims4()?;
-        let qkv = self
-            .qkv
-            .forward(xs)?
-            .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
-            .permute((2, 0, 3, 1, 4))?
-            .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
-        let q = qkv.i(0)?;
-        let k = qkv.i(1)?;
-        let v = qkv.i(2)?;
-        let attn = (q * self.scale)?.matmul(&k.t()?)?;
-        if self.use_rel_pos {
-            todo!()
-        }
-        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
-        let attn = attn
-            .matmul(&v)?
-            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
-            .permute((0, 2, 3, 1, 4))?
-            .reshape((b, h, w, c / self.num_heads))?;
-        self.proj.forward(&attn)
-    }
-}
-
-#[derive(Debug)]
-struct Block {
-    norm1: LayerNorm,
-    attn: Attention,
-    norm2: LayerNorm,
-    mlp: MlpBlock,
-    window_size: usize,
-}
-
-impl Block {
-    fn new(
-        dim: usize,
-        num_heads: usize,
-        qkv_bias: bool,
-        use_rel_pos: bool,
-        window_size: usize,
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
-        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
-        let attn = Attention::new(
-            dim,
-            num_heads,
-            qkv_bias,
-            use_rel_pos,
-            window_size,
-            vb.pp("attn"),
-        )?;
-        let mlp = MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
-        Ok(Self {
-            norm1,
-            attn,
-            norm2,
-            mlp,
-            window_size,
-        })
-    }
-}
-
-impl Module for Block {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let shortcut = xs;
-        let xs = self.norm1.forward(xs)?;
-        if self.window_size > 0 {
-            todo!()
-        }
-        let xs = self.attn.forward(&xs)?;
-        if self.window_size > 0 {
-            todo!()
-        }
-        let xs = (xs + shortcut)?;
-        &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
-    }
-}
-
-#[derive(Debug)]
-struct ImageEncoderViT {
-    img_size: usize,
-    patch_embed: PatchEmbed,
-    blocks: Vec<Block>,
-    neck_conv1: candle_nn::Conv2d,
-    neck_ln1: LayerNorm,
-    neck_conv2: candle_nn::Conv2d,
-    neck_ln2: LayerNorm,
-    pos_embed: Option<Tensor>,
-}
-
-impl ImageEncoderViT {
-    #[allow(clippy::too_many_arguments)]
-    fn new(
-        img_size: usize,
-        patch_size: usize,
-        in_chans: usize,
-        embed_dim: usize,
-        depth: usize,
-        num_heads: usize,
-        out_chans: usize,
-        qkv_bias: bool,
-        use_rel_pos: bool,
-        use_abs_pos: bool,
-        window_size: usize,
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let patch_embed = PatchEmbed::new(
-            in_chans,
-            embed_dim,
-            patch_size,
-            patch_size,
-            0,
-            vb.pp("patch_embed"),
-        )?;
-        let mut blocks = Vec::with_capacity(depth);
-        let vb_b = vb.pp("blocks");
-        for i in 0..depth {
-            let block = Block::new(
-                embed_dim,
-                num_heads,
-                qkv_bias,
-                use_rel_pos,
-                window_size,
-                vb_b.pp(i),
-            )?;
-            blocks.push(block)
-        }
-        let neck_conv1 = candle_nn::conv2d_no_bias(
-            embed_dim,
-            out_chans,
-            1,
-            Default::default(),
-            vb.pp("neck.0"),
-        )?;
-        let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?;
-        let cfg = candle_nn::Conv2dConfig {
-            padding: 1,
-            ..Default::default()
-        };
-        let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
-        let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?;
-        let pos_embed = if use_abs_pos {
-            let p = vb.get(
-                (1, img_size / patch_size, img_size / patch_size, embed_dim),
-                "pos_embed",
-            )?;
-            Some(p)
-        } else {
-            None
-        };
-        Ok(Self {
-            img_size,
-            patch_embed,
-            blocks,
-            neck_conv1,
-            neck_ln1,
-            neck_conv2,
-            neck_ln2,
-            pos_embed,
-        })
-    }
-}
-
-impl Module for ImageEncoderViT {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.patch_embed.forward(xs)?;
-        let mut xs = match &self.pos_embed {
-            Some(pos_embed) => (xs + pos_embed)?,
-            None => xs,
-        };
-        for block in self.blocks.iter() {
-            xs = block.forward(&xs)?
-        }
-        xs.permute((0, 3, 1, 2))?
-            .apply(&self.neck_conv1)?
-            .apply(&self.neck_ln1)?
-            .apply(&self.neck_conv2)?
-            .apply(&self.neck_ln2)
-    }
-}
-
-#[derive(Debug)]
-struct MlpMaskDecoder {
-    layers: Vec<Linear>,
-    sigmoid_output: bool,
-}
-
-impl MlpMaskDecoder {
-    fn new(
-        input_dim: usize,
-        hidden_dim: usize,
-        output_dim: usize,
-        num_layers: usize,
-        sigmoid_output: bool,
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let mut layers = Vec::with_capacity(num_layers);
-        let vb = vb.pp("layers");
-        for i in 0..num_layers {
-            let in_dim = if i == 0 { input_dim } else { hidden_dim };
-            let out_dim = if i + 1 == num_layers {
-                output_dim
-            } else {
-                hidden_dim
-            };
-            let layer = linear(vb.pp(i), in_dim, out_dim, true)?;
-            layers.push(layer)
-        }
-        Ok(Self {
-            layers,
-            sigmoid_output,
-        })
-    }
-}
-
-impl Module for MlpMaskDecoder {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = xs.clone();
-        for (i, layer) in self.layers.iter().enumerate() {
-            xs = layer.forward(&xs)?;
-            if i + 1 < self.layers.len() {
-                xs = xs.relu()?
-            }
-        }
-        if self.sigmoid_output {
-            candle_nn::ops::sigmoid(&xs)
-        } else {
-            Ok(xs)
-        }
-    }
-}
-
-#[derive(Debug)]
-struct MaskDecoder {
-    iou_tokens: candle_nn::Embedding,
-    mask_tokens: candle_nn::Embedding,
-    iou_prediction_head: MlpMaskDecoder,
-}
-
-impl MaskDecoder {
-    fn new(
-        transformer_dim: usize,
-        num_multimask_outputs: usize,
-        iou_head_depth: usize,
-        iou_head_hidden_dim: usize,
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let num_mask_tokens = num_multimask_outputs - 1;
-        let iou_prediction_head = MlpMaskDecoder::new(
-            transformer_dim,
-            iou_head_hidden_dim,
-            num_mask_tokens,
-            iou_head_depth,
-            false,
-            vb.pp("iou_prediction_head"),
-        )?;
-        let iou_tokens = candle_nn::embedding(1, transformer_dim, vb.pp("iou_tokens"))?;
-        let mask_tokens =
-            candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
-        Ok(Self {
-            iou_tokens,
-            mask_tokens,
-            iou_prediction_head,
-        })
-    }
-}
-
 /*
 fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
     let npatch = xs.dim(1)? - 1;
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
new file mode 100644
index 00000000..c8b6fd7b
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -0,0 +1,257 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+#[derive(Debug)]
+struct PatchEmbed {
+    proj: candle_nn::Conv2d,
+}
+
+impl PatchEmbed {
+    fn new(
+        in_chans: usize,
+        embed_dim: usize,
+        k_size: usize,
+        stride: usize,
+        padding: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let cfg = candle_nn::Conv2dConfig {
+            stride,
+            padding,
+            ..Default::default()
+        };
+        let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
+        Ok(Self { proj })
+    }
+}
+
+impl Module for PatchEmbed {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.proj)?.permute((0, 2, 3, 1))
+    }
+}
+
+#[derive(Debug)]
+struct Attention {
+    qkv: Linear,
+    proj: Linear,
+    num_heads: usize,
+    scale: f64,
+    use_rel_pos: bool,
+    rel_pos_hw: Option<(Tensor, Tensor)>,
+}
+
+impl Attention {
+    fn new(
+        dim: usize,
+        num_heads: usize,
+        qkv_bias: bool,
+        use_rel_pos: bool,
+        window_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
+        let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
+        let head_dim = dim / num_heads;
+        let scale = 1. / (head_dim as f64).sqrt();
+        let rel_pos_hw = if use_rel_pos {
+            let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
+            let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
+            Some((h, w))
+        } else {
+            None
+        };
+        Ok(Self {
+            qkv,
+            proj,
+            num_heads,
+            scale,
+            use_rel_pos,
+            rel_pos_hw,
+        })
+    }
+}
+
+impl Module for Attention {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b, h, w, c) = xs.dims4()?;
+        let qkv = self
+            .qkv
+            .forward(xs)?
+            .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
+            .permute((2, 0, 3, 1, 4))?
+            .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
+        let q = qkv.i(0)?;
+        let k = qkv.i(1)?;
+        let v = qkv.i(2)?;
+        let attn = (q * self.scale)?.matmul(&k.t()?)?;
+        if self.use_rel_pos {
+            todo!()
+        }
+        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+        let attn = attn
+            .matmul(&v)?
+            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
+            .permute((0, 2, 3, 1, 4))?
+            .reshape((b, h, w, c / self.num_heads))?;
+        self.proj.forward(&attn)
+    }
+}
+
+#[derive(Debug)]
+struct Block {
+    norm1: LayerNorm,
+    attn: Attention,
+    norm2: LayerNorm,
+    mlp: crate::MlpBlock,
+    window_size: usize,
+}
+
+impl Block {
+    fn new(
+        dim: usize,
+        num_heads: usize,
+        qkv_bias: bool,
+        use_rel_pos: bool,
+        window_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
+        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+        let attn = Attention::new(
+            dim,
+            num_heads,
+            qkv_bias,
+            use_rel_pos,
+            window_size,
+            vb.pp("attn"),
+        )?;
+        let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
+        Ok(Self {
+            norm1,
+            attn,
+            norm2,
+            mlp,
+            window_size,
+        })
+    }
+}
+
+impl Module for Block {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let shortcut = xs;
+        let xs = self.norm1.forward(xs)?;
+        if self.window_size > 0 {
+            todo!()
+        }
+        let xs = self.attn.forward(&xs)?;
+        if self.window_size > 0 {
+            todo!()
+        }
+        let xs = (xs + shortcut)?;
+        &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
+    }
+}
+
+#[derive(Debug)]
+struct ImageEncoderViT {
+    img_size: usize,
+    patch_embed: PatchEmbed,
+    blocks: Vec<Block>,
+    neck_conv1: candle_nn::Conv2d,
+    neck_ln1: LayerNorm,
+    neck_conv2: candle_nn::Conv2d,
+    neck_ln2: LayerNorm,
+    pos_embed: Option<Tensor>,
+}
+
+impl ImageEncoderViT {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        img_size: usize,
+        patch_size: usize,
+        in_chans: usize,
+        embed_dim: usize,
+        depth: usize,
+        num_heads: usize,
+        out_chans: usize,
+        qkv_bias: bool,
+        use_rel_pos: bool,
+        use_abs_pos: bool,
+        window_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let patch_embed = PatchEmbed::new(
+            in_chans,
+            embed_dim,
+            patch_size,
+            patch_size,
+            0,
+            vb.pp("patch_embed"),
+        )?;
+        let mut blocks = Vec::with_capacity(depth);
+        let vb_b = vb.pp("blocks");
+        for i in 0..depth {
+            let block = Block::new(
+                embed_dim,
+                num_heads,
+                qkv_bias,
+                use_rel_pos,
+                window_size,
+                vb_b.pp(i),
+            )?;
+            blocks.push(block)
+        }
+        let neck_conv1 = candle_nn::conv2d_no_bias(
+            embed_dim,
+            out_chans,
+            1,
+            Default::default(),
+            vb.pp("neck.0"),
+        )?;
+        let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?;
+        let cfg = candle_nn::Conv2dConfig {
+            padding: 1,
+            ..Default::default()
+        };
+        let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
+        let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?;
+        let pos_embed = if use_abs_pos {
+            let p = vb.get(
+                (1, img_size / patch_size, img_size / patch_size, embed_dim),
+                "pos_embed",
+            )?;
+            Some(p)
+        } else {
+            None
+        };
+        Ok(Self {
+            img_size,
+            patch_embed,
+            blocks,
+            neck_conv1,
+            neck_ln1,
+            neck_conv2,
+            neck_ln2,
+            pos_embed,
+        })
+    }
+}
+
+impl Module for ImageEncoderViT {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.patch_embed.forward(xs)?;
+        let mut xs = match &self.pos_embed {
+            Some(pos_embed) => (xs + pos_embed)?,
+            None => xs,
+        };
+        for block in self.blocks.iter() {
+            xs = block.forward(&xs)?
+        }
+        xs.permute((0, 3, 1, 2))?
+            .apply(&self.neck_conv1)?
+            .apply(&self.neck_ln1)?
+            .apply(&self.neck_conv2)?
+            .apply(&self.neck_ln2)
+    }
+}
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
new file mode 100644
index 00000000..55a006c4
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -0,0 +1,222 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+#[derive(Debug)]
+struct MlpMaskDecoder {
+    layers: Vec<Linear>,
+    sigmoid_output: bool,
+}
+
+impl MlpMaskDecoder {
+    fn new(
+        input_dim: usize,
+        hidden_dim: usize,
+        output_dim: usize,
+        num_layers: usize,
+        sigmoid_output: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let mut layers = Vec::with_capacity(num_layers);
+        let vb = vb.pp("layers");
+        for i in 0..num_layers {
+            let in_dim = if i == 0 { input_dim } else { hidden_dim };
+            let out_dim = if i + 1 == num_layers {
+                output_dim
+            } else {
+                hidden_dim
+            };
+            let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
+            layers.push(layer)
+        }
+        Ok(Self {
+            layers,
+            sigmoid_output,
+        })
+    }
+}
+
+impl Module for MlpMaskDecoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for (i, layer) in self.layers.iter().enumerate() {
+            xs = layer.forward(&xs)?;
+            if i + 1 < self.layers.len() {
+                xs = xs.relu()?
+            }
+        }
+        if self.sigmoid_output {
+            candle_nn::ops::sigmoid(&xs)
+        } else {
+            Ok(xs)
+        }
+    }
+}
+
+#[derive(Debug)]
+struct MaskDecoder {
+    iou_token: candle_nn::Embedding,
+    mask_tokens: candle_nn::Embedding,
+    iou_prediction_head: MlpMaskDecoder,
+    output_upscaling_conv1: candle_nn::ConvTranspose2d,
+    output_upscaling_ln: LayerNorm,
+    output_upscaling_conv2: candle_nn::ConvTranspose2d,
+    num_mask_tokens: usize,
+    output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
+}
+
+impl MaskDecoder {
+    fn new(
+        transformer_dim: usize,
+        num_multimask_outputs: usize,
+        iou_head_depth: usize,
+        iou_head_hidden_dim: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let num_mask_tokens = num_multimask_outputs - 1;
+        let iou_prediction_head = MlpMaskDecoder::new(
+            transformer_dim,
+            iou_head_hidden_dim,
+            num_mask_tokens,
+            iou_head_depth,
+            false,
+            vb.pp("iou_prediction_head"),
+        )?;
+        let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
+        let mask_tokens =
+            candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
+        let cfg = candle_nn::ConvTranspose2dConfig {
+            stride: 2,
+            ..Default::default()
+        };
+        let output_upscaling_conv1 = candle_nn::conv_transpose2d(
+            transformer_dim,
+            transformer_dim / 4,
+            2,
+            cfg,
+            vb.pp("output_upscaling.0"),
+        )?;
+        let output_upscaling_ln =
+            layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
+        let output_upscaling_conv2 = candle_nn::conv_transpose2d(
+            transformer_dim / 4,
+            transformer_dim / 8,
+            2,
+            cfg,
+            vb.pp("output_upscaling.3"),
+        )?;
+        let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
+        let vb_o = vb.pp("output_hypernetworks_mlps");
+        for i in 0..num_mask_tokens {
+            let mlp = MlpMaskDecoder::new(
+                transformer_dim,
+                transformer_dim,
+                transformer_dim / 8,
+                3,
+                false,
+                vb_o.pp(i),
+            )?;
+            output_hypernetworks_mlps.push(mlp)
+        }
+        Ok(Self {
+            iou_token,
+            mask_tokens,
+            iou_prediction_head,
+            output_upscaling_conv1,
+            output_upscaling_ln,
+            output_upscaling_conv2,
+            num_mask_tokens,
+            output_hypernetworks_mlps,
+        })
+    }
+
+    fn forward(
+        &self,
+        image_embeddings: &Tensor,
+        image_pe: &Tensor,
+        sparse_prompt_embeddings: &Tensor,
+        dense_prompt_embeddings: &Tensor,
+        multimask_output: bool,
+    ) -> Result<(Tensor, Tensor)> {
+        let (masks, iou_pred) = self.predict_masks(
+            image_embeddings,
+            image_pe,
+            sparse_prompt_embeddings,
+            dense_prompt_embeddings,
+        )?;
+        let masks = if multimask_output {
+            masks.i((.., 1..))?
+        } else {
+            masks.i((.., 0..1))?
+        };
+        let iou_pred = if multimask_output {
+            iou_pred.i((.., 1..))?
+        } else {
+            iou_pred.i((.., 0..1))?
+        };
+        Ok((masks, iou_pred))
+    }
+
+    fn predict_masks(
+        &self,
+        image_embeddings: &Tensor,
+        image_pe: &Tensor,
+        sparse_prompt_embeddings: &Tensor,
+        dense_prompt_embeddings: &Tensor,
+    ) -> Result<(Tensor, Tensor)> {
+        // Concatenate output tokens.
+        let output_tokens = Tensor::cat(
+            &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
+            0,
+        )?;
+        let (d1, d2) = output_tokens.dims2()?;
+        let output_tokens =
+            output_tokens
+                .unsqueeze(0)?
+                .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
+        let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
+
+        // Expand per-image data in batch direction to be per mask
+        let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
+        let src = (src + dense_prompt_embeddings)?;
+        let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
+        let (b, c, h, w) = src.dims4()?;
+
+        // Run the transformer
+        let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
+        let iou_token_out = hs.i((.., 0))?;
+        let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
+
+        // Upscale mask embeddings and predict masks using the mask tokens.
+        let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
+        let upscaled_embedding = self
+            .output_upscaling_conv1
+            .forward(&src)?
+            .apply(&self.output_upscaling_ln)?
+            .gelu()?
+            .apply(&self.output_upscaling_conv2)?
+            .gelu()?;
+        let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
+        for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
+            let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
+            hyper_in_list.push(h)
+        }
+        let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
+        let (b, c, h, w) = upscaled_embedding.dims4()?;
+        let masks = hyper_in
+            .matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
+            .reshape((b, self.num_mask_tokens, h, w))?;
+
+        // Generate mask quality predictions.
+        let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
+        Ok((masks, iou_pred))
+    }
+}
+
+// Equivalent to torch.repeat_interleave
+fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
+    todo!()
+}
+
+fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
+    todo!()
+}
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
new file mode 100644
index 00000000..10f7f4e5
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_transformer.rs
@@ -0,0 +1,77 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+#[derive(Debug)]
+struct Attention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    out_proj: Linear,
+    internal_dim: usize,
+    num_heads: usize,
+}
+
+impl Attention {
+    fn new(
+        embedding_dim: usize,
+        num_heads: usize,
+        downsample_rate: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let internal_dim = embedding_dim / downsample_rate;
+        let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
+        let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
+        let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
+        let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            out_proj,
+            internal_dim,
+            num_heads,
+        })
+    }
+
+    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
+        let (b, n, c) = x.dims3()?;
+        x.reshape((b, n, self.num_heads, c / self.num_heads))?
+            .transpose(1, 2)
+    }
+
+    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
+        let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
+        x.transpose(1, 2)?
+            .reshape((b, n_tokens, n_heads * c_per_head))
+    }
+
+    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+        let q = self.q_proj.forward(q)?;
+        let k = self.k_proj.forward(k)?;
+        let v = self.v_proj.forward(v)?;
+
+        let q = self.separate_heads(&q)?;
+        let k = self.separate_heads(&k)?;
+        let v = self.separate_heads(&v)?;
+
+        let (_, _, _, c_per_head) = q.dims4()?;
+        let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
+        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+
+        let out = attn.matmul(&v)?;
+        self.recombine_heads(&out)?.apply(&self.out_proj)
+    }
+}
+
+#[derive(Debug)]
+struct TwoWayAttentionBlock {
+    self_attn: Attention,
+    norm1: LayerNorm,
+    cross_attn_token_to_image: Attention,
+    norm2: LayerNorm,
+    mlp: crate::MlpBlock,
+    norm3: LayerNorm,
+    norm4: LayerNorm,
+    cross_attn_image_to_token: Attention,
+    skip_first_layer_pe: bool,
+}
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index fe44c153..b2483058 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -130,6 +130,17 @@ pub struct ConvTranspose2dConfig {
     // TODO: support groups.
 }
 
+impl Default for ConvTranspose2dConfig {
+    fn default() -> Self {
+        Self {
+            padding: 0,
+            output_padding: 0,
+            stride: 1,
+            dilation: 1,
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct ConvTranspose2d {
     weight: Tensor,