diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index c53c1010..0f0c0482 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -1,6 +1,5 @@
 //! SAM: Segment Anything Model
 //! https://github.com/facebookresearch/segment-anything
-#![allow(unused)]
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
@@ -14,7 +13,7 @@ pub mod model_prompt_encoder;
 pub mod model_sam;
 pub mod model_transformer;
 
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{DType, Result, Tensor};
 use candle_nn::{Linear, Module, VarBuilder};
 use clap::Parser;
 
@@ -101,6 +100,15 @@ struct Args {
     /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
+
+    #[arg(long)]
+    generate_masks: bool,
+
+    #[arg(long)]
+    point_x: Option<f64>,
+
+    #[arg(long)]
+    point_y: Option<f64>,
 }
 
 pub fn main() -> anyhow::Result<()> {
@@ -108,7 +116,7 @@ pub fn main() -> anyhow::Result<()> {
 
     let device = candle_examples::device(args.cpu)?;
 
-    let image = if args.image.ends_with(".safetensors") {
+    let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
         let mut tensors = candle::safetensors::load(&args.image, &device)?;
         let image = match tensors.remove("image") {
             Some(image) => image,
@@ -119,13 +127,16 @@ pub fn main() -> anyhow::Result<()> {
                 tensors.into_values().next().unwrap()
             }
         };
-        if image.rank() == 4 {
+        let image = if image.rank() == 4 {
             image.get(0)?
         } else {
             image
-        }
+        };
+        let (_c, h, w) = image.dims3()?;
+        (image, h, w)
     } else {
-        candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
+        let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?;
+        (image.to_device(&device)?, h, w)
     };
     println!("loaded image {image:?}");
 
@@ -142,19 +153,30 @@ pub fn main() -> anyhow::Result<()> {
     let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
     let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
 
-    let (mask, iou_predictions) = sam.forward(&image, false)?;
-    println!("mask:\n{mask}");
-    println!("iou_predictions: {iou_predictions:?}");
+    if args.generate_masks {
+        // Default options similar to the Python version.
+        sam.generate_masks(
+            &image,
+            /* points_per_side */ 32,
+            /* crop_n_layer */ 0,
+            /* crop_overlap_ratio */ 512. / 1500.,
+            /* crop_n_points_downscale_factor */ 1,
+        )?
+    } else {
+        let point = args.point_x.zip(args.point_y);
+        let (mask, iou_predictions) = sam.forward(&image, point, false)?;
+        println!("mask:\n{mask}");
+        println!("iou_predictions: {iou_predictions:?}");
 
-    // Save the mask as an image.
-    let mask = mask.ge(&mask.zeros_like()?)?;
-    let mask = (mask * 255.)?.squeeze(0)?;
-    let (_one, h, w) = mask.dims3()?;
-    let mask = mask.expand((3, h, w))?;
-    candle_examples::save_image(&mask, "sam_mask.png")?;
+        // Save the mask as an image.
+        let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
+        let (_one, h, w) = mask.dims3()?;
+        let mask = mask.expand((3, h, w))?;
+        candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
 
-    let image = sam.preprocess(&image)?;
-    let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
-    candle_examples::save_image(&image, "sam_input_scaled.png")?;
+        let image = sam.preprocess(&image)?;
+        let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
+        candle_examples::save_image(&image, "sam_input_scaled.png")?;
+    }
     Ok(())
 }
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index 79e52d47..f1b76e23 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{DType, IndexOp, Result, Tensor};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
 
 #[derive(Debug)]
@@ -37,7 +37,6 @@ struct Attention {
     proj: Linear,
     num_heads: usize,
     scale: f64,
-    use_rel_pos: bool,
     rel_pos_hw: Option<(Tensor, Tensor)>,
 }
 
@@ -66,7 +65,6 @@ impl Attention {
             proj,
             num_heads,
             scale,
-            use_rel_pos,
             rel_pos_hw,
         })
     }
@@ -272,7 +270,6 @@ impl Module for Block {
 
 #[derive(Debug)]
 pub struct ImageEncoderViT {
-    img_size: usize,
     patch_embed: PatchEmbed,
     blocks: Vec<Block>,
     neck_conv1: candle_nn::Conv2d,
@@ -350,7 +347,6 @@ impl ImageEncoderViT {
             None
         };
         Ok(Self {
-            img_size,
             patch_embed,
             blocks,
             neck_conv1,
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index acbfeeea..598af1f6 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{IndexOp, Result, Tensor};
 use candle_nn::{Linear, Module, VarBuilder};
 
 use crate::model_transformer::TwoWayTransformer;
@@ -188,7 +188,7 @@ impl MaskDecoder {
 
         // Expand per-image data in batch direction to be per mask
         let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
-        let src = (src + dense_prompt_embeddings)?;
+        let src = src.broadcast_add(dense_prompt_embeddings)?;
         let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
         let (b, c, h, w) = src.dims4()?;
 
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index e4291ebb..b401a900 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -1,5 +1,5 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::VarBuilder;
 
 #[derive(Debug)]
 struct PostionEmbeddingRandom {
@@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
 
     fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
         let device = self.positional_encoding_gaussian_matrix.device();
-        let grid = Tensor::ones((h, w), DType::F32, device)?;
         let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
         let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
         let x_embed = (x_embed / w as f64)?
@@ -157,8 +156,9 @@ impl PromptEncoder {
         let point_embedding = self
             .pe_layer
             .forward_with_coords(&points, self.input_image_size)?;
+        let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
         let zeros = point_embedding.zeros_like()?;
-        let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
+        let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
             &self
                 .not_a_point_embed
                 .embeddings()
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index 237163a3..884559af 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -1,5 +1,5 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle::{DType, IndexOp, Result, Tensor};
+use candle_nn::{Module, VarBuilder};
 
 use crate::model_image_encoder::ImageEncoderViT;
 use crate::model_mask_decoder::MaskDecoder;
@@ -70,12 +70,30 @@ impl Sam {
         })
     }
 
-    pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> {
+    pub fn forward(
+        &self,
+        img: &Tensor,
+        point: Option<(f64, f64)>,
+        multimask_output: bool,
+    ) -> Result<(Tensor, Tensor)> {
+        let (_c, original_h, original_w) = img.dims3()?;
         let img = self.preprocess(img)?.unsqueeze(0)?;
         let img_embeddings = self.image_encoder.forward(&img)?;
         let image_pe = self.prompt_encoder.get_dense_pe()?;
+        let points = match point {
+            None => None,
+            Some((x, y)) => {
+                let points = Tensor::new(
+                    &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
+                    img.device(),
+                )?;
+                let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
+                Some((points, labels))
+            }
+        };
+        let points = points.as_ref().map(|(x, y)| (x, y));
         let (sparse_prompt_embeddings, dense_prompt_embeddings) =
-            self.prompt_encoder.forward(None, None, None)?;
+            self.prompt_encoder.forward(points, None, None)?;
         let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
             &img_embeddings,
             &image_pe,
@@ -83,8 +101,11 @@ impl Sam {
             &dense_prompt_embeddings,
             multimask_output,
         )?;
-        // TODO: post-processing.
-        Ok((low_res_mask, iou_predictions))
+        let mask = low_res_mask
+            .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
+            .get(0)?
+            .i((.., ..original_h, ..original_w))?;
+        Ok((mask, iou_predictions))
     }
 
     pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
@@ -96,7 +117,7 @@ impl Sam {
     }
 
     pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
-        let (c, h, w) = img.dims3()?;
+        let (_c, h, w) = img.dims3()?;
         let img = img
             .to_dtype(DType::F32)?
             .broadcast_sub(&self.pixel_mean)?
@@ -107,4 +128,150 @@ impl Sam {
         let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
         img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
     }
+
+    fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
+        // Crop the image and calculate embeddings.
+        let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
+        let img = self.preprocess(&img)?.unsqueeze(0)?;
+        let img_embeddings = self.image_encoder.forward(&img)?;
+
+        let crop_w = cb.x1 - cb.x0;
+        let crop_h = cb.y1 - cb.y0;
+
+        // Generate masks for this crop.
+        let image_pe = self.prompt_encoder.get_dense_pe()?;
+        let points = point_grids
+            .iter()
+            .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
+            .collect::<Vec<_>>();
+        for points in points.chunks(64) {
+            let points_len = points.len();
+            let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
+            let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
+            let (sparse_prompt_embeddings, dense_prompt_embeddings) =
+                self.prompt_encoder
+                    .forward(Some((&in_points, &in_labels)), None, None)?;
+            let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
+                &img_embeddings,
+                &image_pe,
+                &sparse_prompt_embeddings,
+                &dense_prompt_embeddings,
+                /* multimask_output */ true,
+            )?;
+
+            println!("{cb:?} {iou_predictions}");
+        }
+
+        // Remove duplicates within this crop.
+
+        // Return to the original image frame.
+        Ok(())
+    }
+
+    pub fn generate_masks(
+        &self,
+        img: &Tensor,
+        points_per_side: usize,
+        crop_n_layer: usize,
+        crop_overlap_ratio: f64,
+        crop_n_points_downscale_factor: usize,
+    ) -> Result<()> {
+        let (_c, h, w) = img.dims3()?;
+        let point_grids = build_all_layer_point_grids(
+            points_per_side,
+            crop_n_layer,
+            crop_n_points_downscale_factor,
+        );
+        let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
+        for crop_box in crop_boxes.into_iter() {
+            let layer_idx = crop_box.layer_idx;
+            self.process_crop(img, crop_box, &point_grids[layer_idx])?
+        }
+        // TODO: remove duplicates
+        Ok(())
+    }
+}
+
+#[derive(Debug)]
+struct CropBox {
+    x0: usize,
+    y0: usize,
+    x1: usize,
+    y1: usize,
+    layer_idx: usize,
+}
+
+impl CropBox {
+    fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
+        Self {
+            x0,
+            y0,
+            x1,
+            y1,
+            layer_idx,
+        }
+    }
+}
+
+fn generate_crop_boxes(
+    (im_h, im_w): (usize, usize),
+    n_layers: usize,
+    overlap_ratio: f64,
+) -> Vec<CropBox> {
+    fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
+        f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
+    }
+
+    let short_side = usize::min(im_h, im_w);
+
+    let mut crop_boxes = Vec::new();
+
+    // Original image.
+    crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
+
+    for layer_idx in 1..=n_layers {
+        let n_crops_per_side = 1 << layer_idx;
+        let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
+        let crop_w = crop_len(im_w, n_crops_per_side, overlap);
+        let crop_h = crop_len(im_h, n_crops_per_side, overlap);
+
+        for i_x in 0..n_crops_per_side {
+            let x0 = (crop_w - overlap) * i_x;
+            for i_y in 0..n_crops_per_side {
+                let y0 = (crop_h - overlap) * i_y;
+                let x1 = usize::min(im_w, x0 + crop_w);
+                let y1 = usize::min(im_h, y0 + crop_h);
+                crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
+            }
+        }
+    }
+
+    crop_boxes
+}
+
+// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
+fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
+    let offset = 1f64 / (2 * n_per_side) as f64;
+    let mut points = Vec::with_capacity(n_per_side * n_per_side);
+    for i_x in 0..n_per_side {
+        let x = offset + i_x as f64 / n_per_side as f64;
+        for i_y in 0..n_per_side {
+            let y = offset + i_y as f64 / n_per_side as f64;
+            points.push((x, y))
+        }
+    }
+    points
+}
+
+fn build_all_layer_point_grids(
+    n_per_side: usize,
+    n_layers: usize,
+    scale_per_layer: usize,
+) -> Vec<Vec<(f64, f64)>> {
+    let mut points_by_layer = Vec::with_capacity(n_layers + 1);
+    for i in 0..=n_layers {
+        let n_points = n_per_side / scale_per_layer.pow(i as u32);
+        points_by_layer.push(build_point_grid(n_points))
+    }
+    points_by_layer
+}
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
index 044dce9b..e4de27cb 100644
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ b/candle-examples/examples/segment-anything/model_transformer.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{Result, Tensor};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
 
 #[derive(Debug)]
@@ -7,7 +7,6 @@ struct Attention {
     k_proj: Linear,
     v_proj: Linear,
     out_proj: Linear,
-    internal_dim: usize,
     num_heads: usize,
 }
 
@@ -28,7 +27,6 @@ impl Attention {
             k_proj,
             v_proj,
             out_proj,
-            internal_dim,
             num_heads,
         })
     }
@@ -85,7 +83,6 @@ impl TwoWayAttentionBlock {
         skip_first_layer_pe: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
-        let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
         let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
         let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
         let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
@@ -204,7 +201,6 @@ impl TwoWayTransformer {
         image_pe: &Tensor,
         point_embedding: &Tensor,
     ) -> Result<(Tensor, Tensor)> {
-        let (bs, c, h, w) = image_embedding.dims4()?;
         let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
         let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
 
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 66cd2f99..c14b2d6b 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -19,10 +19,11 @@ pub fn device(cpu: bool) -> Result<Device> {
 pub fn load_image<P: AsRef<std::path::Path>>(
     p: P,
     resize_longest: Option<usize>,
-) -> Result<Tensor> {
+) -> Result<(Tensor, usize, usize)> {
     let img = image::io::Reader::open(p)?
         .decode()
         .map_err(candle::Error::wrap)?;
+    let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
     let img = match resize_longest {
         None => img,
         Some(resize_longest) => {
@@ -41,7 +42,8 @@ pub fn load_image<P: AsRef<std::path::Path>>(
     let (height, width) = (img.height() as usize, img.width() as usize);
     let img = img.to_rgb8();
     let data = img.into_raw();
-    Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
+    let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+    Ok((data, initial_h, initial_w))
 }
 
 pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
@@ -80,3 +82,27 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
     image.save(p).map_err(candle::Error::wrap)?;
     Ok(())
 }
+
+pub fn save_image_resize<P: AsRef<std::path::Path>>(
+    img: &Tensor,
+    p: P,
+    h: usize,
+    w: usize,
+) -> Result<()> {
+    let p = p.as_ref();
+    let (channel, height, width) = img.dims3()?;
+    if channel != 3 {
+        candle::bail!("save_image expects an input of shape (3, height, width)")
+    }
+    let img = img.permute((1, 2, 0))?.flatten_all()?;
+    let pixels = img.to_vec1::<u8>()?;
+    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+            Some(image) => image,
+            None => candle::bail!("error saving image {p:?}"),
+        };
+    let image = image::DynamicImage::from(image);
+    let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
+    image.save(p).map_err(candle::Error::wrap)?;
+    Ok(())
+}
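
For reference, the point-prompt path introduced above can be exercised end to end as sketched below. This is a minimal sketch, not part of the diff: it assumes the example's modules (`model_sam`, `candle_examples`) are in scope, that `sam.safetensors` holds sam_vit_b weights loaded the same way as in the unchanged part of `main.rs`, and that `sample.jpg` exists; adjust paths to your setup.

```rust
use candle::DType;
use candle_nn::VarBuilder;

fn point_prompt_demo() -> anyhow::Result<()> {
    let device = candle_examples::device(/* cpu */ true)?;
    // `load_image` now also returns the pre-resize height and width.
    let (image, initial_h, initial_w) =
        candle_examples::load_image("sample.jpg", Some(model_sam::IMAGE_SIZE))?;
    let image = image.to_device(&device)?;

    let weights = unsafe { candle::safetensors::MmapedFile::new("sam.safetensors")? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b

    // The point prompt is normalized to [0, 1]x[0, 1]; `forward` scales it to
    // pixel coordinates of the resized image before encoding it.
    let (mask, iou_predictions) = sam.forward(&image, Some((0.5, 0.5)), false)?;
    println!("iou_predictions: {iou_predictions:?}");

    // Threshold the mask logits at zero and restore the original resolution.
    let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
    let (_one, h, w) = mask.dims3()?;
    let mask = mask.expand((3, h, w))?;
    candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
    Ok(())
}
```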
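The `--generate-masks` path samples a uniform grid of normalized points per crop layer, mirroring the Python automatic mask generator. A hypothetical unit test for `build_point_grid` (not part of the diff; it would live in the same module since the function is private) illustrates the spacing: for `n_per_side = 2` the offset is 1/(2*2) = 0.25, so points land at 0.25 and 0.75 on each axis.

```rust
#[test]
fn point_grid_is_centered_and_uniform() {
    let grid = build_point_grid(2);
    assert_eq!(grid.len(), 4);
    // The x coordinate varies slowest: (0.25, 0.25), (0.25, 0.75), (0.75, 0.25), (0.75, 0.75).
    assert_eq!(grid[0], (0.25, 0.25));
    assert_eq!(grid[1], (0.25, 0.75));
    assert_eq!(grid[3], (0.75, 0.75));
}
```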