From 6527ab81a3a6f26dae37f6c56eecf0f4bb826f02 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 7 Sep 2023 06:34:05 +0200
Subject: [PATCH] Sketch the segment anything model. (#759)

* Sketch the segment anything model.

* Fix some clippy lint.

* Add the mask decoder.
---
 candle-core/src/shape.rs                      |  11 +
 .../examples/segment-anything/main.rs         | 446 ++++++++++++++++++
 2 files changed, 457 insertions(+)
 create mode 100644 candle-examples/examples/segment-anything/main.rs

diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index db0fe98a..578e8ac9 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -425,6 +425,17 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
     }
 }
 
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        let d2 = self.2.to_index(shape, op)?;
+        let d3 = self.3.to_index(shape, op)?;
+        let d4 = self.4.to_index(shape, op)?;
+        Ok(vec![d0, d1, d2, d3, d4])
+    }
+}
+
 extract_dims!(dims0, 0, |_: &[usize]| (), ());
 extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
 extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
new file mode 100644
index 00000000..a53cff8b
--- /dev/null
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -0,0 +1,446 @@
+//! SAM: Segment Anything Model
+//! https://github.com/facebookresearch/segment-anything
+#![allow(unused)]
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::Parser;
+
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+const IMG_SIZE: usize = 518;
+const PATCH_SIZE: usize = 14;
+const NUM_CLASSES: usize = 1000;
+
+fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+    if bias {
+        candle_nn::linear(in_dim, out_dim, vb)
+    } else {
+        candle_nn::linear_no_bias(in_dim, out_dim, vb)
+    }
+}
+
+#[derive(Debug)]
+struct MlpBlock {
+    lin1: Linear,
+    lin2: Linear,
+}
+
+impl MlpBlock {
+    fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
+        let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
+        let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
+        Ok(Self { lin1, lin2 })
+    }
+}
+
+impl Module for MlpBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
+    }
+}
+
+#[derive(Debug)]
+struct PatchEmbed {
+    proj: candle_nn::Conv2d,
+}
+
+impl PatchEmbed {
+    fn new(
+        in_chans: usize,
+        embed_dim: usize,
+        k_size: usize,
+        stride: usize,
+        padding: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let cfg = candle_nn::Conv2dConfig {
+            stride,
+            padding,
+            ..Default::default()
+        };
+        let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
+        Ok(Self { proj })
+    }
+}
+
+impl Module for PatchEmbed {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.proj)?.permute((0, 2, 3, 1))
+    }
+}
+
+#[derive(Debug)]
+struct Attention {
+    qkv: Linear,
+    proj: Linear,
+    num_heads: usize,
+    scale: f64,
+    use_rel_pos: bool,
+    rel_pos_hw: Option<(Tensor, Tensor)>,
+}
+
+impl Attention {
+    fn new(
+        dim: usize,
+        num_heads: usize,
+        qkv_bias: bool,
+        use_rel_pos: bool,
+        window_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
+        let proj = linear(vb.pp("proj"), dim, dim, true)?;
+        let head_dim = dim / num_heads;
+        let scale = 1. / (head_dim as f64).sqrt();
+        let rel_pos_hw = if use_rel_pos {
+            let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
+            let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
+            Some((h, w))
+        } else {
+            None
+        };
+        Ok(Self {
+            qkv,
+            proj,
+            num_heads,
+            scale,
+            use_rel_pos,
+            rel_pos_hw,
+        })
+    }
+}
+
+impl Module for Attention {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b, h, w, c) = xs.dims4()?;
+        let qkv = self
+            .qkv
+            .forward(xs)?
+            .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
+            .permute((2, 0, 3, 1, 4))?
+            .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
+        let q = qkv.i(0)?;
+        let k = qkv.i(1)?;
+        let v = qkv.i(2)?;
+        let attn = (q * self.scale)?.matmul(&k.t()?)?;
+        if self.use_rel_pos {
+            todo!()
+        }
+        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+        let attn = attn
+            .matmul(&v)?
+            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
+            .permute((0, 2, 3, 1, 4))?
+            .reshape((b, h, w, c))?;
+        self.proj.forward(&attn)
+    }
+}
+
+#[derive(Debug)]
+struct Block {
+    norm1: LayerNorm,
+    attn: Attention,
+    norm2: LayerNorm,
+    mlp: MlpBlock,
+    window_size: usize,
+}
+
+impl Block {
+    fn new(
+        dim: usize,
+        num_heads: usize,
+        qkv_bias: bool,
+        use_rel_pos: bool,
+        window_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
+        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+        let attn = Attention::new(
+            dim,
+            num_heads,
+            qkv_bias,
+            use_rel_pos,
+            window_size,
+            vb.pp("attn"),
+        )?;
+        let mlp = MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
+        Ok(Self {
+            norm1,
+            attn,
+            norm2,
+            mlp,
+            window_size,
+        })
+    }
+}
+
+impl Module for Block {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let shortcut = xs;
+        let xs = self.norm1.forward(xs)?;
+        if self.window_size > 0 {
+            todo!()
+        }
+        let xs = self.attn.forward(&xs)?;
+        if self.window_size > 0 {
+            todo!()
+        }
+        let xs = (xs + shortcut)?;
+        &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
+    }
+}
+
+#[derive(Debug)]
+struct ImageEncoderViT {
+    img_size: usize,
+    patch_embed: PatchEmbed,
+    blocks: Vec<Block>,
+    neck_conv1: candle_nn::Conv2d,
+    neck_ln1: LayerNorm,
+    neck_conv2: candle_nn::Conv2d,
+    neck_ln2: LayerNorm,
+    pos_embed: Option<Tensor>,
+}
+
+impl ImageEncoderViT {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        img_size: usize,
+        patch_size: usize,
+        in_chans: usize,
+        embed_dim: usize,
+        depth: usize,
+        num_heads: usize,
+        out_chans: usize,
+        qkv_bias: bool,
+        use_rel_pos: bool,
+        use_abs_pos: bool,
+        window_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let patch_embed = PatchEmbed::new(
+            in_chans,
+            embed_dim,
+            patch_size,
+            patch_size,
+            0,
+            vb.pp("patch_embed"),
+        )?;
+        let mut blocks = Vec::with_capacity(depth);
+        let vb_b = vb.pp("blocks");
+        for i in 0..depth {
+            let block = Block::new(
+                embed_dim,
+                num_heads,
+                qkv_bias,
+                use_rel_pos,
+                window_size,
+                vb_b.pp(i),
+            )?;
+            blocks.push(block)
+        }
+        let neck_conv1 = candle_nn::conv2d_no_bias(
+            embed_dim,
+            out_chans,
+            1,
+            Default::default(),
+            vb.pp("neck.0"),
+        )?;
+        let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?;
+        let cfg = candle_nn::Conv2dConfig {
+            padding: 1,
+            ..Default::default()
+        };
+        let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
+        let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?;
+        let pos_embed = if use_abs_pos {
+            let p = vb.get(
+                (1, img_size / patch_size, img_size / patch_size, embed_dim),
+                "pos_embed",
+            )?;
+            Some(p)
+        } else {
+            None
+        };
+        Ok(Self {
+            img_size,
+            patch_embed,
+            blocks,
+            neck_conv1,
+            neck_ln1,
+            neck_conv2,
+            neck_ln2,
+            pos_embed,
+        })
+    }
+}
+
+impl Module for ImageEncoderViT {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.patch_embed.forward(xs)?;
+        let mut xs = match &self.pos_embed {
+            Some(pos_embed) => (xs + pos_embed)?,
+            None => xs,
+        };
+        for block in self.blocks.iter() {
+            xs = block.forward(&xs)?
+        }
+        xs.permute((0, 3, 1, 2))?
+            .apply(&self.neck_conv1)?
+            .apply(&self.neck_ln1)?
+            .apply(&self.neck_conv2)?
+            .apply(&self.neck_ln2)
+    }
+}
+
+#[derive(Debug)]
+struct MlpMaskDecoder {
+    layers: Vec<Linear>,
+    sigmoid_output: bool,
+}
+
+impl MlpMaskDecoder {
+    fn new(
+        input_dim: usize,
+        hidden_dim: usize,
+        output_dim: usize,
+        num_layers: usize,
+        sigmoid_output: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let mut layers = Vec::with_capacity(num_layers);
+        let vb = vb.pp("layers");
+        for i in 0..num_layers {
+            let in_dim = if i == 0 { input_dim } else { hidden_dim };
+            let out_dim = if i + 1 == num_layers {
+                output_dim
+            } else {
+                hidden_dim
+            };
+            let layer = linear(vb.pp(i), in_dim, out_dim, true)?;
+            layers.push(layer)
+        }
+        Ok(Self {
+            layers,
+            sigmoid_output,
+        })
+    }
+}
+
+impl Module for MlpMaskDecoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for (i, layer) in self.layers.iter().enumerate() {
+            xs = layer.forward(&xs)?;
+            if i + 1 < self.layers.len() {
+                xs = xs.relu()?
+            }
+        }
+        if self.sigmoid_output {
+            candle_nn::ops::sigmoid(&xs)
+        } else {
+            Ok(xs)
+        }
+    }
+}
+
+#[derive(Debug)]
+struct MaskDecoder {
+    iou_tokens: candle_nn::Embedding,
+    mask_tokens: candle_nn::Embedding,
+    iou_prediction_head: MlpMaskDecoder,
+}
+
+impl MaskDecoder {
+    fn new(
+        transformer_dim: usize,
+        num_multimask_outputs: usize,
+        iou_head_depth: usize,
+        iou_head_hidden_dim: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let num_mask_tokens = num_multimask_outputs - 1;
+        let iou_prediction_head = MlpMaskDecoder::new(
+            transformer_dim,
+            iou_head_hidden_dim,
+            num_mask_tokens,
+            iou_head_depth,
+            false,
+            vb.pp("iou_prediction_head"),
+        )?;
+        let iou_tokens = candle_nn::embedding(1, transformer_dim, vb.pp("iou_tokens"))?;
+        let mask_tokens =
+            candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
+        Ok(Self {
+            iou_tokens,
+            mask_tokens,
+            iou_prediction_head,
+        })
+    }
+}
+
+/*
+    fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
+        let npatch = xs.dim(1)? - 1;
+        let n = self.pos_embed.dim(1)? - 1;
+        let sqrt_n = (n as f64).sqrt();
+        if npatch == n && w == h {
+            return Ok(xs.clone());
+        }
+        let class_pos_embed = self.pos_embed.i((.., ..1))?;
+        let patch_pos_embed = self.pos_embed.i((.., 1..))?;
+        let dim = xs.dim(D::Minus1)?;
+        let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
+        let patch_pos_embed = patch_pos_embed
+            .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
+            .transpose(2, 3)?
+            .transpose(1, 2)?;
+        // This uses bicubic interpolation in the original implementation.
+        let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
+        let el_count = patch_pos_embed.shape().elem_count();
+        let patch_pos_embed = patch_pos_embed
+            .transpose(1, 2)?
+            .transpose(2, 3)?
+            .reshape((1, el_count / dim, dim))?;
+        Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
+    }
+
+    fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
+        let (_b, _nc, w, h) = xs.dims4()?;
+        let xs = self.patch_embed.forward(xs)?;
+        let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
+        &xs + &self.interpolate_pos_encoding(&xs, w, h)?
+    }
+*/
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    #[arg(long)]
+    image: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    let _device = candle_examples::device(args.cpu)?;
+
+    let image = candle_examples::imagenet::load_image224(args.image)?;
+    println!("loaded image {image:?}");
+
+    Ok(())
+}
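
Note (not part of the patch): the new `Dims` impl for 5-tuples in candle-core/src/shape.rs is what lets the attention code above call `permute((2, 0, 3, 1, 4))` and `reshape` with five-element tuples. A minimal standalone sketch of that usage, with made-up shapes chosen only for illustration:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // A dummy 5-D tensor shaped like the (b, h*w, 3, num_heads, head_dim)
    // intermediate produced in `Attention::forward`.
    let xs = Tensor::zeros((2, 16, 3, 4, 8), DType::F32, &Device::Cpu)?;
    // The 5-element index tuple below is resolved through the
    // `Dims for (D1, D2, D3, D4, D5)` impl added in this patch.
    let ys = xs.permute((2, 0, 3, 1, 4))?;
    assert_eq!(ys.dims(), &[3, 2, 4, 16, 8]);
    Ok(())
}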