Automatic mask generator + point base mask (#773)

* Add more to the automatic mask generator.

* Add the target point.

* Fix.

* Remove the allow-unused.

* Mask post-processing.
This commit is contained in:
Laurent Mazare
2023-09-08 12:26:56 +01:00
committed by GitHub
parent c1453f00b1
commit 28c87f6a34
7 changed files with 249 additions and 42 deletions

View File

@ -1,6 +1,5 @@
//! SAM: Segment Anything Model
//! https://github.com/facebookresearch/segment-anything
#![allow(unused)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@ -14,7 +13,7 @@ pub mod model_prompt_encoder;
pub mod model_sam;
pub mod model_transformer;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle::{DType, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use clap::Parser;
@ -101,6 +100,15 @@ struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
generate_masks: bool,
#[arg(long)]
point_x: Option<f64>,
#[arg(long)]
point_y: Option<f64>,
}
pub fn main() -> anyhow::Result<()> {
@ -108,7 +116,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
let image = if args.image.ends_with(".safetensors") {
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
let image = match tensors.remove("image") {
Some(image) => image,
@ -119,13 +127,16 @@ pub fn main() -> anyhow::Result<()> {
tensors.into_values().next().unwrap()
}
};
if image.rank() == 4 {
let image = if image.rank() == 4 {
image.get(0)?
} else {
image
}
};
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@ -142,19 +153,30 @@ pub fn main() -> anyhow::Result<()> {
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
let (mask, iou_predictions) = sam.forward(&image, false)?;
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
if args.generate_masks {
// Default options similar to the Python version.
sam.generate_masks(
&image,
/* points_per_side */ 32,
/* crop_n_layer */ 0,
/* crop_overlap_ratio */ 512. / 1500.,
/* crop_n_points_downscale_factor */ 1,
)?
} else {
let point = args.point_x.zip(args.point_y);
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
// Save the mask as an image.
let mask = mask.ge(&mask.zeros_like()?)?;
let mask = (mask * 255.)?.squeeze(0)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
candle_examples::save_image(&mask, "sam_mask.png")?;
// Save the mask as an image.
let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
let image = sam.preprocess(&image)?;
let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
candle_examples::save_image(&image, "sam_input_scaled.png")?;
let image = sam.preprocess(&image)?;
let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
candle_examples::save_image(&image, "sam_input_scaled.png")?;
}
Ok(())
}

View File

@ -1,4 +1,4 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
@ -37,7 +37,6 @@ struct Attention {
proj: Linear,
num_heads: usize,
scale: f64,
use_rel_pos: bool,
rel_pos_hw: Option<(Tensor, Tensor)>,
}
@ -66,7 +65,6 @@ impl Attention {
proj,
num_heads,
scale,
use_rel_pos,
rel_pos_hw,
})
}
@ -272,7 +270,6 @@ impl Module for Block {
#[derive(Debug)]
pub struct ImageEncoderViT {
img_size: usize,
patch_embed: PatchEmbed,
blocks: Vec<Block>,
neck_conv1: candle_nn::Conv2d,
@ -350,7 +347,6 @@ impl ImageEncoderViT {
None
};
Ok(Self {
img_size,
patch_embed,
blocks,
neck_conv1,

View File

@ -1,4 +1,4 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
@ -188,7 +188,7 @@ impl MaskDecoder {
// Expand per-image data in batch direction to be per mask
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
let src = (src + dense_prompt_embeddings)?;
let src = src.broadcast_add(dense_prompt_embeddings)?;
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
let (b, c, h, w) = src.dims4()?;

View File

@ -1,5 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{Linear, Module, VarBuilder};
use candle_nn::VarBuilder;
#[derive(Debug)]
struct PostionEmbeddingRandom {
@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
let device = self.positional_encoding_gaussian_matrix.device();
let grid = Tensor::ones((h, w), DType::F32, device)?;
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let x_embed = (x_embed / w as f64)?
@ -157,8 +156,9 @@ impl PromptEncoder {
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
&self
.not_a_point_embed
.embeddings()

View File

@ -1,5 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use crate::model_image_encoder::ImageEncoderViT;
use crate::model_mask_decoder::MaskDecoder;
@ -70,12 +70,30 @@ impl Sam {
})
}
pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> {
pub fn forward(
&self,
img: &Tensor,
point: Option<(f64, f64)>,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let (_c, original_h, original_w) = img.dims3()?;
let img = self.preprocess(img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = match point {
None => None,
Some((x, y)) => {
let points = Tensor::new(
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
img.device(),
)?;
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
Some((points, labels))
}
};
let points = points.as_ref().map(|(x, y)| (x, y));
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder.forward(None, None, None)?;
self.prompt_encoder.forward(points, None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
@ -83,8 +101,11 @@ impl Sam {
&dense_prompt_embeddings,
multimask_output,
)?;
// TODO: post-processing.
Ok((low_res_mask, iou_predictions))
let mask = low_res_mask
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
.get(0)?
.i((.., ..original_h, ..original_w))?;
Ok((mask, iou_predictions))
}
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
@ -96,7 +117,7 @@ impl Sam {
}
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
let (c, h, w) = img.dims3()?;
let (_c, h, w) = img.dims3()?;
let img = img
.to_dtype(DType::F32)?
.broadcast_sub(&self.pixel_mean)?
@ -107,4 +128,150 @@ impl Sam {
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
}
fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
// Crop the image and calculate embeddings.
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
let img = self.preprocess(&img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
let crop_w = cb.x1 - cb.x0;
let crop_h = cb.y1 - cb.y0;
// Generate masks for this crop.
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = point_grids
.iter()
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
.collect::<Vec<_>>();
for points in points.chunks(64) {
let points_len = points.len();
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder
.forward(Some((&in_points, &in_labels)), None, None)?;
let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
/* multimask_output */ true,
)?;
println!("{cb:?} {iou_predictions}");
}
// Remove duplicates within this crop.
// Return to the original image frame.
Ok(())
}
pub fn generate_masks(
&self,
img: &Tensor,
points_per_side: usize,
crop_n_layer: usize,
crop_overlap_ratio: f64,
crop_n_points_downscale_factor: usize,
) -> Result<()> {
let (_c, h, w) = img.dims3()?;
let point_grids = build_all_layer_point_grids(
points_per_side,
crop_n_layer,
crop_n_points_downscale_factor,
);
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
for crop_box in crop_boxes.into_iter() {
let layer_idx = crop_box.layer_idx;
self.process_crop(img, crop_box, &point_grids[layer_idx])?
}
// TODO: remove duplicates
Ok(())
}
}
#[derive(Debug)]
struct CropBox {
x0: usize,
y0: usize,
x1: usize,
y1: usize,
layer_idx: usize,
}
impl CropBox {
fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
Self {
x0,
y0,
x1,
y1,
layer_idx,
}
}
}
fn generate_crop_boxes(
(im_h, im_w): (usize, usize),
n_layers: usize,
overlap_ratio: f64,
) -> Vec<CropBox> {
fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
}
let short_side = usize::min(im_h, im_w);
let mut crop_boxes = Vec::new();
// Original image.
crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
for layer_idx in 1..=n_layers {
let n_crops_per_side = 1 << layer_idx;
let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
let crop_w = crop_len(im_w, n_crops_per_side, overlap);
let crop_h = crop_len(im_w, n_crops_per_side, overlap);
for i_x in 0..n_crops_per_side {
let x0 = (crop_w - overlap) * i_x;
for i_y in 0..n_crops_per_side {
let y0 = (crop_h - overlap) * i_y;
let x1 = usize::min(im_w, x0 + crop_w);
let y1 = usize::min(im_h, y0 + crop_h);
crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
}
}
}
crop_boxes
}
// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
let offset = 1f64 / (2 * n_per_side) as f64;
let mut points = Vec::with_capacity(n_per_side * n_per_side);
for i_x in 0..n_per_side {
let x = offset + i_x as f64 / n_per_side as f64;
for i_y in 0..n_per_side {
let y = offset + i_y as f64 / n_per_side as f64;
points.push((x, y))
}
}
points
}
fn build_all_layer_point_grids(
n_per_side: usize,
n_layers: usize,
scale_per_layer: usize,
) -> Vec<Vec<(f64, f64)>> {
let mut points_by_layer = Vec::with_capacity(n_layers + 1);
for i in 0..=n_layers {
let n_points = n_per_side / scale_per_layer.pow(i as u32);
points_by_layer.push(build_point_grid(n_points))
}
points_by_layer
}

View File

@ -1,4 +1,4 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
@ -7,7 +7,6 @@ struct Attention {
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
internal_dim: usize,
num_heads: usize,
}
@ -28,7 +27,6 @@ impl Attention {
k_proj,
v_proj,
out_proj,
internal_dim,
num_heads,
})
}
@ -85,7 +83,6 @@ impl TwoWayAttentionBlock {
skip_first_layer_pe: bool,
vb: VarBuilder,
) -> Result<Self> {
let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
@ -204,7 +201,6 @@ impl TwoWayTransformer {
image_pe: &Tensor,
point_embedding: &Tensor,
) -> Result<(Tensor, Tensor)> {
let (bs, c, h, w) = image_embedding.dims4()?;
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;

View File

@ -19,10 +19,11 @@ pub fn device(cpu: bool) -> Result<Device> {
pub fn load_image<P: AsRef<std::path::Path>>(
p: P,
resize_longest: Option<usize>,
) -> Result<Tensor> {
) -> Result<(Tensor, usize, usize)> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?;
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
let img = match resize_longest {
None => img,
Some(resize_longest) => {
@ -41,7 +42,8 @@ pub fn load_image<P: AsRef<std::path::Path>>(
let (height, width) = (img.height() as usize, img.width() as usize);
let img = img.to_rgb8();
let data = img.into_raw();
Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
Ok((data, initial_h, initial_w))
}
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
@ -80,3 +82,27 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
pub fn save_image_resize<P: AsRef<std::path::Path>>(
img: &Tensor,
p: P,
h: usize,
w: usize,
) -> Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle::bail!("error saving image {p:?}"),
};
let image = image::DynamicImage::from(image);
let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}