mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator. * Add the target point. * Fix. * Remove the allow-unused. * Mask post-processing.
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
//! SAM: Segment Anything Model
|
||||
//! https://github.com/facebookresearch/segment-anything
|
||||
#![allow(unused)]
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
@ -14,7 +13,7 @@ pub mod model_prompt_encoder;
|
||||
pub mod model_sam;
|
||||
pub mod model_transformer;
|
||||
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle::{DType, Result, Tensor};
|
||||
use candle_nn::{Linear, Module, VarBuilder};
|
||||
use clap::Parser;
|
||||
|
||||
@ -101,6 +100,15 @@ struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long)]
|
||||
generate_masks: bool,
|
||||
|
||||
#[arg(long)]
|
||||
point_x: Option<f64>,
|
||||
|
||||
#[arg(long)]
|
||||
point_y: Option<f64>,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
@ -108,7 +116,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = if args.image.ends_with(".safetensors") {
|
||||
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
|
||||
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
||||
let image = match tensors.remove("image") {
|
||||
Some(image) => image,
|
||||
@ -119,13 +127,16 @@ pub fn main() -> anyhow::Result<()> {
|
||||
tensors.into_values().next().unwrap()
|
||||
}
|
||||
};
|
||||
if image.rank() == 4 {
|
||||
let image = if image.rank() == 4 {
|
||||
image.get(0)?
|
||||
} else {
|
||||
image
|
||||
}
|
||||
};
|
||||
let (_c, h, w) = image.dims3()?;
|
||||
(image, h, w)
|
||||
} else {
|
||||
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
|
||||
let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?;
|
||||
(image.to_device(&device)?, h, w)
|
||||
};
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
@ -142,19 +153,30 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
|
||||
|
||||
let (mask, iou_predictions) = sam.forward(&image, false)?;
|
||||
if args.generate_masks {
|
||||
// Default options similar to the Python version.
|
||||
sam.generate_masks(
|
||||
&image,
|
||||
/* points_per_side */ 32,
|
||||
/* crop_n_layer */ 0,
|
||||
/* crop_overlap_ratio */ 512. / 1500.,
|
||||
/* crop_n_points_downscale_factor */ 1,
|
||||
)?
|
||||
} else {
|
||||
let point = args.point_x.zip(args.point_y);
|
||||
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
||||
println!("mask:\n{mask}");
|
||||
println!("iou_predictions: {iou_predictions:?}");
|
||||
|
||||
// Save the mask as an image.
|
||||
let mask = mask.ge(&mask.zeros_like()?)?;
|
||||
let mask = (mask * 255.)?.squeeze(0)?;
|
||||
let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
|
||||
let (_one, h, w) = mask.dims3()?;
|
||||
let mask = mask.expand((3, h, w))?;
|
||||
candle_examples::save_image(&mask, "sam_mask.png")?;
|
||||
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
|
||||
|
||||
let image = sam.preprocess(&image)?;
|
||||
let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
|
||||
candle_examples::save_image(&image, "sam_input_scaled.png")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -37,7 +37,6 @@ struct Attention {
|
||||
proj: Linear,
|
||||
num_heads: usize,
|
||||
scale: f64,
|
||||
use_rel_pos: bool,
|
||||
rel_pos_hw: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
@ -66,7 +65,6 @@ impl Attention {
|
||||
proj,
|
||||
num_heads,
|
||||
scale,
|
||||
use_rel_pos,
|
||||
rel_pos_hw,
|
||||
})
|
||||
}
|
||||
@ -272,7 +270,6 @@ impl Module for Block {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ImageEncoderViT {
|
||||
img_size: usize,
|
||||
patch_embed: PatchEmbed,
|
||||
blocks: Vec<Block>,
|
||||
neck_conv1: candle_nn::Conv2d,
|
||||
@ -350,7 +347,6 @@ impl ImageEncoderViT {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
img_size,
|
||||
patch_embed,
|
||||
blocks,
|
||||
neck_conv1,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle::{IndexOp, Result, Tensor};
|
||||
use candle_nn::{Linear, Module, VarBuilder};
|
||||
|
||||
use crate::model_transformer::TwoWayTransformer;
|
||||
@ -188,7 +188,7 @@ impl MaskDecoder {
|
||||
|
||||
// Expand per-image data in batch direction to be per mask
|
||||
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
|
||||
let src = (src + dense_prompt_embeddings)?;
|
||||
let src = src.broadcast_add(dense_prompt_embeddings)?;
|
||||
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
|
||||
let (b, c, h, w) = src.dims4()?;
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Linear, Module, VarBuilder};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PostionEmbeddingRandom {
|
||||
@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
|
||||
|
||||
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
|
||||
let device = self.positional_encoding_gaussian_matrix.device();
|
||||
let grid = Tensor::ones((h, w), DType::F32, device)?;
|
||||
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let x_embed = (x_embed / w as f64)?
|
||||
@ -157,8 +156,9 @@ impl PromptEncoder {
|
||||
let point_embedding = self
|
||||
.pe_layer
|
||||
.forward_with_coords(&points, self.input_image_size)?;
|
||||
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
|
||||
let zeros = point_embedding.zeros_like()?;
|
||||
let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||
let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||
&self
|
||||
.not_a_point_embed
|
||||
.embeddings()
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
|
||||
use crate::model_image_encoder::ImageEncoderViT;
|
||||
use crate::model_mask_decoder::MaskDecoder;
|
||||
@ -70,12 +70,30 @@ impl Sam {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> {
|
||||
pub fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
point: Option<(f64, f64)>,
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_c, original_h, original_w) = img.dims3()?;
|
||||
let img = self.preprocess(img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let points = match point {
|
||||
None => None,
|
||||
Some((x, y)) => {
|
||||
let points = Tensor::new(
|
||||
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
|
||||
img.device(),
|
||||
)?;
|
||||
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
|
||||
Some((points, labels))
|
||||
}
|
||||
};
|
||||
let points = points.as_ref().map(|(x, y)| (x, y));
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder.forward(None, None, None)?;
|
||||
self.prompt_encoder.forward(points, None, None)?;
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
&image_pe,
|
||||
@ -83,8 +101,11 @@ impl Sam {
|
||||
&dense_prompt_embeddings,
|
||||
multimask_output,
|
||||
)?;
|
||||
// TODO: post-processing.
|
||||
Ok((low_res_mask, iou_predictions))
|
||||
let mask = low_res_mask
|
||||
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
|
||||
.get(0)?
|
||||
.i((.., ..original_h, ..original_w))?;
|
||||
Ok((mask, iou_predictions))
|
||||
}
|
||||
|
||||
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||
@ -96,7 +117,7 @@ impl Sam {
|
||||
}
|
||||
|
||||
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||
let (c, h, w) = img.dims3()?;
|
||||
let (_c, h, w) = img.dims3()?;
|
||||
let img = img
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_sub(&self.pixel_mean)?
|
||||
@ -107,4 +128,150 @@ impl Sam {
|
||||
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
|
||||
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
|
||||
}
|
||||
|
||||
fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
|
||||
// Crop the image and calculate embeddings.
|
||||
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
|
||||
let img = self.preprocess(&img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
|
||||
let crop_w = cb.x1 - cb.x0;
|
||||
let crop_h = cb.y1 - cb.y0;
|
||||
|
||||
// Generate masks for this crop.
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let points = point_grids
|
||||
.iter()
|
||||
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
|
||||
.collect::<Vec<_>>();
|
||||
for points in points.chunks(64) {
|
||||
let points_len = points.len();
|
||||
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
|
||||
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder
|
||||
.forward(Some((&in_points, &in_labels)), None, None)?;
|
||||
let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
/* multimask_output */ true,
|
||||
)?;
|
||||
|
||||
println!("{cb:?} {iou_predictions}");
|
||||
}
|
||||
|
||||
// Remove duplicates within this crop.
|
||||
|
||||
// Return to the original image frame.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn generate_masks(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
points_per_side: usize,
|
||||
crop_n_layer: usize,
|
||||
crop_overlap_ratio: f64,
|
||||
crop_n_points_downscale_factor: usize,
|
||||
) -> Result<()> {
|
||||
let (_c, h, w) = img.dims3()?;
|
||||
let point_grids = build_all_layer_point_grids(
|
||||
points_per_side,
|
||||
crop_n_layer,
|
||||
crop_n_points_downscale_factor,
|
||||
);
|
||||
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
|
||||
for crop_box in crop_boxes.into_iter() {
|
||||
let layer_idx = crop_box.layer_idx;
|
||||
self.process_crop(img, crop_box, &point_grids[layer_idx])?
|
||||
}
|
||||
// TODO: remove duplicates
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CropBox {
|
||||
x0: usize,
|
||||
y0: usize,
|
||||
x1: usize,
|
||||
y1: usize,
|
||||
layer_idx: usize,
|
||||
}
|
||||
|
||||
impl CropBox {
|
||||
fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
|
||||
Self {
|
||||
x0,
|
||||
y0,
|
||||
x1,
|
||||
y1,
|
||||
layer_idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_crop_boxes(
|
||||
(im_h, im_w): (usize, usize),
|
||||
n_layers: usize,
|
||||
overlap_ratio: f64,
|
||||
) -> Vec<CropBox> {
|
||||
fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
|
||||
f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
|
||||
}
|
||||
|
||||
let short_side = usize::min(im_h, im_w);
|
||||
|
||||
let mut crop_boxes = Vec::new();
|
||||
|
||||
// Original image.
|
||||
crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
|
||||
|
||||
for layer_idx in 1..=n_layers {
|
||||
let n_crops_per_side = 1 << layer_idx;
|
||||
let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
|
||||
let crop_w = crop_len(im_w, n_crops_per_side, overlap);
|
||||
let crop_h = crop_len(im_w, n_crops_per_side, overlap);
|
||||
|
||||
for i_x in 0..n_crops_per_side {
|
||||
let x0 = (crop_w - overlap) * i_x;
|
||||
for i_y in 0..n_crops_per_side {
|
||||
let y0 = (crop_h - overlap) * i_y;
|
||||
let x1 = usize::min(im_w, x0 + crop_w);
|
||||
let y1 = usize::min(im_h, y0 + crop_h);
|
||||
crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crop_boxes
|
||||
}
|
||||
|
||||
// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
|
||||
fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
|
||||
let offset = 1f64 / (2 * n_per_side) as f64;
|
||||
let mut points = Vec::with_capacity(n_per_side * n_per_side);
|
||||
for i_x in 0..n_per_side {
|
||||
let x = offset + i_x as f64 / n_per_side as f64;
|
||||
for i_y in 0..n_per_side {
|
||||
let y = offset + i_y as f64 / n_per_side as f64;
|
||||
points.push((x, y))
|
||||
}
|
||||
}
|
||||
points
|
||||
}
|
||||
|
||||
fn build_all_layer_point_grids(
|
||||
n_per_side: usize,
|
||||
n_layers: usize,
|
||||
scale_per_layer: usize,
|
||||
) -> Vec<Vec<(f64, f64)>> {
|
||||
let mut points_by_layer = Vec::with_capacity(n_layers + 1);
|
||||
for i in 0..=n_layers {
|
||||
let n_points = n_per_side / scale_per_layer.pow(i as u32);
|
||||
points_by_layer.push(build_point_grid(n_points))
|
||||
}
|
||||
points_by_layer
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle::{Result, Tensor};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -7,7 +7,6 @@ struct Attention {
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
out_proj: Linear,
|
||||
internal_dim: usize,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
@ -28,7 +27,6 @@ impl Attention {
|
||||
k_proj,
|
||||
v_proj,
|
||||
out_proj,
|
||||
internal_dim,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
@ -85,7 +83,6 @@ impl TwoWayAttentionBlock {
|
||||
skip_first_layer_pe: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
|
||||
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
|
||||
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
|
||||
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
|
||||
@ -204,7 +201,6 @@ impl TwoWayTransformer {
|
||||
image_pe: &Tensor,
|
||||
point_embedding: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (bs, c, h, w) = image_embedding.dims4()?;
|
||||
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
|
||||
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
|
||||
|
||||
|
@ -19,10 +19,11 @@ pub fn device(cpu: bool) -> Result<Device> {
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
resize_longest: Option<usize>,
|
||||
) -> Result<Tensor> {
|
||||
) -> Result<(Tensor, usize, usize)> {
|
||||
let img = image::io::Reader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
|
||||
let img = match resize_longest {
|
||||
None => img,
|
||||
Some(resize_longest) => {
|
||||
@ -41,7 +42,8 @@ pub fn load_image<P: AsRef<std::path::Path>>(
|
||||
let (height, width) = (img.height() as usize, img.width() as usize);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
|
||||
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
Ok((data, initial_h, initial_w))
|
||||
}
|
||||
|
||||
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
|
||||
@ -80,3 +82,27 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
|
||||
image.save(p).map_err(candle::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn save_image_resize<P: AsRef<std::path::Path>>(
|
||||
img: &Tensor,
|
||||
p: P,
|
||||
h: usize,
|
||||
w: usize,
|
||||
) -> Result<()> {
|
||||
let p = p.as_ref();
|
||||
let (channel, height, width) = img.dims3()?;
|
||||
if channel != 3 {
|
||||
candle::bail!("save_image expects an input of shape (3, height, width)")
|
||||
}
|
||||
let img = img.permute((1, 2, 0))?.flatten_all()?;
|
||||
let pixels = img.to_vec1::<u8>()?;
|
||||
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
|
||||
Some(image) => image,
|
||||
None => candle::bail!("error saving image {p:?}"),
|
||||
};
|
||||
let image = image::DynamicImage::from(image);
|
||||
let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
|
||||
image.save(p).map_err(candle::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user