diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index de16f70c..368b5a33 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -8,9 +8,11 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-mod model_image_encoder;
-mod model_mask_decoder;
-mod model_transformer;
+pub mod model_image_encoder;
+pub mod model_mask_decoder;
+pub mod model_prompt_encoder;
+pub mod model_sam;
+pub mod model_transformer;
 
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
@@ -82,7 +84,7 @@ impl Module for MlpBlock {
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
-    model: Option<String>,
+    model: String,
 
     #[arg(long)]
     image: String,
@@ -95,10 +97,15 @@ struct Args {
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
 
-    let _device = candle_examples::device(args.cpu)?;
+    let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image224(args.image)?;
+    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
     println!("loaded image {image:?}");
 
+    let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+    let weights = weights.deserialize()?;
+    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+    let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+
     Ok(())
 }
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index c8b6fd7b..cfcdbb38 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -47,7 +47,7 @@ impl Attention {
         num_heads: usize,
         qkv_bias: bool,
         use_rel_pos: bool,
-        window_size: usize,
+        input_size: (usize, usize),
         vb: VarBuilder,
     ) -> Result<Self> {
         let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
@@ -55,8 +55,8 @@
         let head_dim = dim / num_heads;
         let scale = 1. / (head_dim as f64).sqrt();
         let rel_pos_hw = if use_rel_pos {
-            let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
-            let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
+            let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
+            let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
             Some((h, w))
         } else {
             None
@@ -114,16 +114,22 @@ impl Block {
         qkv_bias: bool,
         use_rel_pos: bool,
         window_size: usize,
+        input_size: (usize, usize),
         vb: VarBuilder,
     ) -> Result<Self> {
         let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
         let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+        let input_size_attn = if window_size == 0 {
+            input_size
+        } else {
+            (window_size, window_size)
+        };
         let attn = Attention::new(
             dim,
             num_heads,
             qkv_bias,
             use_rel_pos,
-            window_size,
+            input_size_attn,
             vb.pp("attn"),
         )?;
         let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
@@ -154,7 +160,7 @@ impl Module for Block {
 }
 
 #[derive(Debug)]
-struct ImageEncoderViT {
+pub struct ImageEncoderViT {
     img_size: usize,
     patch_embed: PatchEmbed,
     blocks: Vec<Block>,
@@ -167,7 +173,7 @@ impl ImageEncoderViT {
     #[allow(clippy::too_many_arguments)]
-    fn new(
+    pub fn new(
         img_size: usize,
         patch_size: usize,
         in_chans: usize,
@@ -179,6 +185,7 @@
         use_rel_pos: bool,
         use_abs_pos: bool,
         window_size: usize,
+        global_attn_indexes: &[usize],
         vb: VarBuilder,
     ) -> Result<Self> {
         let patch_embed = PatchEmbed::new(
@@ -192,12 +199,18 @@
         let mut blocks = Vec::with_capacity(depth);
         let vb_b = vb.pp("blocks");
         for i in 0..depth {
+            let window_size = if global_attn_indexes.contains(&i) {
+                0
+            } else {
+                window_size
+            };
             let block = Block::new(
                 embed_dim,
                 num_heads,
                 qkv_bias,
                 use_rel_pos,
                 window_size,
+                (img_size / patch_size, img_size / patch_size),
                 vb_b.pp(i),
             )?;
             blocks.push(block)
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index 55a006c4..cf3879cd 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -1,6 +1,8 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
 
+use crate::model_transformer::TwoWayTransformer;
+
 #[derive(Debug)]
 struct MlpMaskDecoder {
     layers: Vec<Linear>,
@@ -53,7 +55,7 @@ impl Module for MlpMaskDecoder {
 }
 
 #[derive(Debug)]
-struct MaskDecoder {
+pub struct MaskDecoder {
     iou_token: candle_nn::Embedding,
     mask_tokens: candle_nn::Embedding,
     iou_prediction_head: MlpMaskDecoder,
@@ -62,17 +64,18 @@
     output_upscaling_conv2: candle_nn::ConvTranspose2d,
     num_mask_tokens: usize,
     output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
+    transformer: TwoWayTransformer,
 }
 
 impl MaskDecoder {
-    fn new(
+    pub fn new(
         transformer_dim: usize,
         num_multimask_outputs: usize,
         iou_head_depth: usize,
         iou_head_hidden_dim: usize,
         vb: VarBuilder,
     ) -> Result<Self> {
-        let num_mask_tokens = num_multimask_outputs - 1;
+        let num_mask_tokens = num_multimask_outputs + 1;
         let iou_prediction_head = MlpMaskDecoder::new(
             transformer_dim,
             iou_head_hidden_dim,
@@ -117,6 +120,13 @@
             )?;
             output_hypernetworks_mlps.push(mlp)
         }
+        let transformer = TwoWayTransformer::new(
+            /* depth */ 2,
+            /* embedding_dim */ transformer_dim,
+            /* num_heads */ 8,
+            /* mlp_dim */ 2048,
+            vb.pp("transformer"),
+        )?;
         Ok(Self {
             iou_token,
             mask_tokens,
@@ -126,6 +136,7 @@
             output_upscaling_conv2,
             num_mask_tokens,
             output_hypernetworks_mlps,
+            transformer,
         })
     }
 
@@ -182,7 +193,7 @@ impl MaskDecoder {
         let (b, c, h, w) = src.dims4()?;
 
         // Run the transformer
-        let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
+        let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
         let iou_token_out = hs.i((.., 0))?;
         let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
@@ -216,7 +227,3 @@ impl MaskDecoder {
 fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
     todo!()
 }
-
-fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
-    todo!()
-}
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
new file mode 100644
index 00000000..7ac4c66d
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -0,0 +1,192 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+#[derive(Debug)]
+struct PositionEmbeddingRandom {
+    positional_encoding_gaussian_matrix: Tensor,
+}
+
+impl PositionEmbeddingRandom {
+    fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
+        let positional_encoding_gaussian_matrix =
+            vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
+        Ok(Self {
+            positional_encoding_gaussian_matrix,
+        })
+    }
+
+    fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
+        let coords = coords.affine(2., -1.)?;
+        let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?;
+        let coords = (coords * (2. * std::f64::consts::PI))?;
+        Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
+    }
+
+    fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
+        let device = self.positional_encoding_gaussian_matrix.device();
+        let grid = Tensor::ones((h, w), DType::F32, device)?;
+        // TODO: cumsum
+        let x_embed = (&grid - 0.5)?;
+        // TODO: cumsum
+        let y_embed = (&grid - 0.5)?;
+        let x_embed = (x_embed / w as f64)?;
+        let y_embed = (y_embed / h as f64)?;
+        let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
+        self.pe_encoding(&coords)?.permute((2, 0, 1))
+    }
+
+    fn forward_with_coords(
+        &self,
+        coords_input: &Tensor,
+        image_size: (usize, usize),
+    ) -> Result<Tensor> {
+        let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
+        let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
+        let c = coords_input.dim(D::Minus1)?;
+        let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
+        let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
+        self.pe_encoding(&coords)
+    }
+}
+
+#[derive(Debug)]
+pub struct PromptEncoder {
+    pe_layer: PositionEmbeddingRandom,
+    point_embeddings: Vec<candle_nn::Embedding>,
+    not_a_point_embed: candle_nn::Embedding,
+    mask_downscaling_conv1: candle_nn::Conv2d,
+    mask_downscaling_ln1: LayerNorm,
+    mask_downscaling_conv2: candle_nn::Conv2d,
+    mask_downscaling_ln2: LayerNorm,
+    mask_downscaling_conv3: candle_nn::Conv2d,
+    no_mask_embed: candle_nn::Embedding,
+    image_embedding_size: (usize, usize),
+    input_image_size: (usize, usize),
+}
+
+impl PromptEncoder {
+    pub fn new(
+        embed_dim: usize,
+        image_embedding_size: (usize, usize),
+        input_image_size: (usize, usize),
+        mask_in_chans: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let num_points_embeddings = 4;
+        let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
+        let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
+        let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
+        let cfg = candle_nn::Conv2dConfig {
+            stride: 2,
+            ..Default::default()
+        };
+        let mask_downscaling_conv1 =
+            candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
+        let mask_downscaling_conv2 = candle_nn::conv2d(
+            mask_in_chans / 4,
+            mask_in_chans,
+            2,
+            cfg,
+            vb.pp("mask_downscaling.3"),
+        )?;
+        let mask_downscaling_conv3 = candle_nn::conv2d(
+            mask_in_chans,
+            embed_dim,
+            1,
+            Default::default(),
+            vb.pp("mask_downscaling.6"),
+        )?;
+        let mask_downscaling_ln1 =
+            layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+        let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+        let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
+        let vb_e = vb.pp("point_embeddings");
+        for i in 0..num_points_embeddings {
+            let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
+            point_embeddings.push(emb)
+        }
+        Ok(Self {
+            pe_layer,
+            point_embeddings,
+            not_a_point_embed,
+            mask_downscaling_conv1,
+            mask_downscaling_ln1,
+            mask_downscaling_conv2,
+            mask_downscaling_ln2,
+            mask_downscaling_conv3,
+            no_mask_embed,
+            image_embedding_size,
+            input_image_size,
+        })
+    }
+
+    fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
+        masks
+            .apply(&self.mask_downscaling_conv1)?
+            .apply(&self.mask_downscaling_ln1)?
+            .gelu()?
+            .apply(&self.mask_downscaling_conv2)?
+            .apply(&self.mask_downscaling_ln2)?
+            .gelu()?
+            .apply(&self.mask_downscaling_conv3)
+    }
+
+    fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
+        let points = (points + 0.5)?;
+        let points = if pad { todo!() } else { points };
+        let point_embedding = self
+            .pe_layer
+            .forward_with_coords(&points, self.input_image_size)?;
+        // TODO: tweak based on labels.
+        Ok(point_embedding)
+    }
+
+    fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
+        let boxes = (boxes + 0.5)?;
+        let coords = boxes.reshape((boxes.elem_count() / 4, 2, 2))?;
+        let corner_embedding = self
+            .pe_layer
+            .forward_with_coords(&coords, self.input_image_size)?;
+        let ce1 = corner_embedding.i((.., 0))?;
+        let ce2 = corner_embedding.i((.., 1))?;
+        let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
+        let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
+        Tensor::cat(&[&ce1, &ce2], 1)
+    }
+
+    fn forward(
+        &self,
+        points: Option<(&Tensor, &Tensor)>,
+        boxes: Option<&Tensor>,
+        masks: Option<&Tensor>,
+    ) -> Result<(Tensor, Tensor)> {
+        let se_points = match points {
+            Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
+            None => None,
+        };
+        let se_boxes = match boxes {
+            Some(boxes) => Some(self.embed_boxes(boxes)?),
+            None => None,
+        };
+        let sparse_embeddings = match (se_points, se_boxes) {
+            (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
+            (Some(se_points), None) => se_points,
+            (None, Some(se_boxes)) => se_boxes,
+            (None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?,
+        };
+
+        let dense_embeddings = match masks {
+            None => {
+                let emb = self.no_mask_embed.embeddings();
+                emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
+                    1,
+                    emb.elem_count(),
+                    self.image_embedding_size.0,
+                    self.image_embedding_size.1,
+                ))?
+            }
+            Some(masks) => self.embed_masks(masks)?,
+        };
+        Ok((sparse_embeddings, dense_embeddings))
+    }
+}
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
new file mode 100644
index 00000000..5a0d7e8f
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -0,0 +1,72 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+use crate::model_image_encoder::ImageEncoderViT;
+use crate::model_mask_decoder::MaskDecoder;
+use crate::model_prompt_encoder::PromptEncoder;
+
+#[derive(Debug)]
+pub struct Sam {
+    image_encoder: ImageEncoderViT,
+    prompt_encoder: PromptEncoder,
+    mask_decoder: MaskDecoder,
+    pixel_mean: Tensor,
+    pixel_std: Tensor,
+}
+
+impl Sam {
+    pub fn new(
+        encoder_embed_dim: usize,
+        encoder_depth: usize,
+        encoder_num_heads: usize,
+        encoder_global_attn_indexes: &[usize],
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        const PROMPT_EMBED_DIM: usize = 256;
+        const IMAGE_SIZE: usize = 1024;
+        const VIT_PATCH_SIZE: usize = 16;
+
+        let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+        let image_encoder = ImageEncoderViT::new(
+            IMAGE_SIZE,
+            VIT_PATCH_SIZE,
+            3,
+            encoder_embed_dim,
+            encoder_depth,
+            encoder_num_heads,
+            PROMPT_EMBED_DIM,
+            /* qkv_bias */ true,
+            /* use_rel_pos */ true,
+            /* use_abs_pos */ true,
+            /* window_size */ 14,
+            /* global_attn_indexes */ encoder_global_attn_indexes,
+            vb.pp("image_encoder"),
+        )?;
+        let prompt_encoder = PromptEncoder::new(
+            PROMPT_EMBED_DIM,
+            (image_embedding_size, image_embedding_size),
+            (IMAGE_SIZE, IMAGE_SIZE),
+            16,
+            vb.pp("prompt_encoder"),
+        )?;
+        let mask_decoder = MaskDecoder::new(
+            PROMPT_EMBED_DIM,
+            /* num_multimask_outputs */ 3,
+            /* iou_head_depth */ 3,
+            /* iou_head_hidden_dim */ 256,
+            vb.pp("mask_decoder"),
+        )?;
+        let pixel_mean =
+            Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+        let pixel_std =
+            Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+        Ok(Self {
+            image_encoder,
+            prompt_encoder,
+            mask_decoder,
+            pixel_std,
+            pixel_mean,
+        })
+    }
+}
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
index 10f7f4e5..a845085d 100644
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ b/candle-examples/examples/segment-anything/model_transformer.rs
@@ -75,3 +75,145 @@ struct TwoWayAttentionBlock {
     cross_attn_image_to_token: Attention,
     skip_first_layer_pe: bool,
 }
+
+impl TwoWayAttentionBlock {
+    fn new(
+        embedding_dim: usize,
+        num_heads: usize,
+        mlp_dim: usize,
+        skip_first_layer_pe: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
+        let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
+        let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
+        let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
+        let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
+        let cross_attn_token_to_image = Attention::new(
+            embedding_dim,
+            num_heads,
+            2,
+            vb.pp("cross_attn_token_to_image"),
+        )?;
+        let cross_attn_image_to_token = Attention::new(
+            embedding_dim,
+            num_heads,
+            2,
+            vb.pp("cross_attn_image_to_token"),
+        )?;
+        // TODO: use relu in this mlp
+        let mlp = crate::MlpBlock::new(embedding_dim, mlp_dim, vb.pp("mlp"))?;
+        Ok(Self {
+            self_attn,
+            norm1,
+            cross_attn_image_to_token,
+            norm2,
+            mlp,
+            norm3,
+            norm4,
+            cross_attn_token_to_image,
+            skip_first_layer_pe,
+        })
+    }
+
+    fn forward(
+        &self,
+        queries: &Tensor,
+        keys: &Tensor,
+        query_pe: &Tensor,
+        key_pe: &Tensor,
+    ) -> Result<(Tensor, Tensor)> {
+        // Self attention block
+        let queries = if self.skip_first_layer_pe {
+            self.self_attn.forward(queries, queries, queries)?
+        } else {
+            let q = (queries + query_pe)?;
+            let attn_out = self.self_attn.forward(&q, &q, queries)?;
+            (queries + attn_out)?
+        };
+        let queries = self.norm1.forward(&queries)?;
+
+        // Cross attention block, tokens attending to image embedding
+        let q = (&queries + query_pe)?;
+        let k = (keys + key_pe)?;
+        let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
+        let queries = (&queries + attn_out)?;
+        let queries = self.norm2.forward(&queries)?;
+
+        // MLP block
+        let mlp_out = self.mlp.forward(&queries)?;
+        let queries = (queries + mlp_out)?;
+        let queries = self.norm3.forward(&queries)?;
+
+        // Cross attention block, image embedding attending to tokens
+        let q = (&queries + query_pe)?;
+        let k = (keys + key_pe)?;
+        let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
+        let keys = (keys + attn_out)?;
+        let keys = self.norm4.forward(&keys)?;
+
+        Ok((queries, keys))
+    }
+}
+
+#[derive(Debug)]
+pub struct TwoWayTransformer {
+    layers: Vec<TwoWayAttentionBlock>,
+    final_attn_token_to_image: Attention,
+    norm_final_attn: LayerNorm,
+}
+
+impl TwoWayTransformer {
+    pub fn new(
+        depth: usize,
+        embedding_dim: usize,
+        num_heads: usize,
+        mlp_dim: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let vb_l = vb.pp("layers");
+        let mut layers = Vec::with_capacity(depth);
+        for i in 0..depth {
+            let layer =
+                TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
+            layers.push(layer)
+        }
+        let final_attn_token_to_image = Attention::new(
+            embedding_dim,
+            num_heads,
+            2,
+            vb.pp("final_attn_token_to_image"),
+        )?;
+        let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
+        Ok(Self {
+            layers,
+            final_attn_token_to_image,
+            norm_final_attn,
+        })
+    }
+
+    pub fn forward(
+        &self,
+        image_embedding: &Tensor,
+        image_pe: &Tensor,
+        point_embedding: &Tensor,
+    ) -> Result<(Tensor, Tensor)> {
+        let (_bs, _c, _h, _w) = image_embedding.dims4()?;
+        let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
+        let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
+
+        let mut queries = point_embedding.clone();
+        let mut keys = image_embedding;
+
+        for layer in self.layers.iter() {
+            (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
+        }
+
+        let q = (&queries + point_embedding)?;
+        let k = (&keys + image_pe)?;
+        let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
+        let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
+
+        Ok((queries, keys))
+    }
+}
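Note: `main.rs` above hard-codes the `sam_vit_b` encoder shape when calling `model_sam::Sam::new`. For reference, a sketch of a small helper that maps a checkpoint name to the corresponding `Sam::new` arguments; the helper itself is hypothetical (not part of this diff), and the `vit_l`/`vit_h` values are the standard hyper-parameters from the reference SAM release:

    // Hypothetical helper, not part of this diff: checkpoint name ->
    // (encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes).
    fn sam_config(name: &str) -> Option<(usize, usize, usize, &'static [usize])> {
        match name {
            "sam_vit_b" => Some((768, 12, 12, &[2, 5, 8, 11])),
            "sam_vit_l" => Some((1024, 24, 16, &[5, 11, 17, 23])),
            "sam_vit_h" => Some((1280, 32, 16, &[7, 15, 23, 31])),
            _ => None,
        }
    }

With such a helper, the call in `main.rs` would become `let (e, d, h, g) = sam_config("sam_vit_b").unwrap(); let _sam = model_sam::Sam::new(e, d, h, g, vb)?;`, making it easy to switch checkpoints later.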