Automatic mask generation (#779)

* A few more contiguous fixes for cuda.

* Mask generation.

* Generic bbox.

* Generate all the masks.
This commit is contained in:
Laurent Mazare
2023-09-08 19:11:34 +01:00
committed by GitHub
parent 158ff3c609
commit 0906acab91
7 changed files with 125 additions and 26 deletions

View File

@ -188,13 +188,25 @@ pub fn main() -> anyhow::Result<()> {
if args.generate_masks {
// Default options similar to the Python version.
sam.generate_masks(
let bboxes = sam.generate_masks(
&image,
/* points_per_side */ 32,
/* crop_n_layer */ 0,
/* crop_overlap_ratio */ 512. / 1500.,
/* crop_n_points_downscale_factor */ 1,
)?
)?;
for (idx, bbox) in bboxes.iter().enumerate() {
println!("{bbox:?}");
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
let (h, w) = mask.dims2()?;
let mask = mask.broadcast_as((3, h, w))?;
candle_examples::save_image_resize(
&mask,
format!("sam_mask{idx}.png"),
initial_h,
initial_w,
)?;
}
} else {
let point = Some((args.point_x, args.point_y));
let (mask, iou_predictions) = sam.forward(&image, point, false)?;

View File

@ -219,7 +219,7 @@ impl MaskDecoder {
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
hyper_in_list.push(h)
}
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
let (b, c, h, w) = upscaled_embedding.dims4()?;
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
let masks = masks.reshape((b, (), h, w))?;

View File

@ -8,6 +8,11 @@ use crate::model_prompt_encoder::PromptEncoder;
const PROMPT_EMBED_DIM: usize = 256;
pub const IMAGE_SIZE: usize = 1024;
const VIT_PATCH_SIZE: usize = 16;
const PRED_IOU_THRESH: f32 = 0.88;
const STABILITY_SCORE_OFFSET: f32 = 1.0;
const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
const MODEL_MASK_THRESHOLD: f32 = 0.0;
const CROP_NMS_THRESH: f32 = 0.7;
#[derive(Debug)]
pub struct Sam {
@ -129,7 +134,12 @@ impl Sam {
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
}
fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
fn process_crop(
&self,
img: &Tensor,
cb: CropBox,
point_grids: &[(f64, f64)],
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
// Crop the image and calculate embeddings.
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
let img = self.preprocess(&img)?.unsqueeze(0)?;
@ -144,28 +154,86 @@ impl Sam {
.iter()
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
.collect::<Vec<_>>();
let mut bboxes = Vec::new();
for points in points.chunks(64) {
// Run the model on this batch.
let points_len = points.len();
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder
.forward(Some((&in_points, &in_labels)), None, None)?;
let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
/* multimask_output */ true,
)?;
let low_res_mask = low_res_mask.flatten(0, 1)?;
let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
let dev = low_res_mask.device();
println!("{cb:?} {iou_predictions}");
for (i, iou) in iou_predictions.iter().enumerate() {
// Filter by predicted IoU.
if *iou < PRED_IOU_THRESH {
continue;
}
let low_res_mask = low_res_mask.get(i)?;
// Calculate stability score.
let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
.broadcast_as(low_res_mask.shape())?;
let intersections = low_res_mask
.ge(&bound)?
.to_dtype(DType::F32)?
.sum_all()?
.to_vec0::<f32>()?;
let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
.broadcast_as(low_res_mask.shape())?;
let unions = low_res_mask
.ge(&bound)?
.to_dtype(DType::F32)?
.sum_all()?
.to_vec0::<f32>()?;
let stability_score = intersections / unions;
if stability_score < STABILITY_SCORE_THRESHOLD {
continue;
}
// Threshold masks and calculate boxes.
let low_res_mask = low_res_mask
.ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
.to_dtype(DType::U32)?;
let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
let min_max_x = min_max_indexes(&low_res_mask_per_x);
let min_max_y = min_max_indexes(&low_res_mask_per_y);
if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
let bbox = candle_examples::object_detection::Bbox {
xmin: x0 as f32,
ymin: y0 as f32,
xmax: x1 as f32,
ymax: y1 as f32,
confidence: *iou,
data: low_res_mask,
};
bboxes.push(bbox);
}
// TODO:
// Filter boxes that touch crop boundaries
// Compress to RLE.
}
}
let mut bboxes = vec![bboxes];
// Remove duplicates within this crop.
candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
// Return to the original image frame.
Ok(())
// TODO: Return to the original image frame.
Ok(bboxes.remove(0))
}
pub fn generate_masks(
@ -175,7 +243,7 @@ impl Sam {
crop_n_layer: usize,
crop_overlap_ratio: f64,
crop_n_points_downscale_factor: usize,
) -> Result<()> {
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
let (_c, h, w) = img.dims3()?;
let point_grids = build_all_layer_point_grids(
points_per_side,
@ -183,12 +251,31 @@ impl Sam {
crop_n_points_downscale_factor,
);
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
let mut bboxes = Vec::new();
for crop_box in crop_boxes.into_iter() {
let layer_idx = crop_box.layer_idx;
self.process_crop(img, crop_box, &point_grids[layer_idx])?
let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
bboxes.extend(b)
}
// TODO: remove duplicates
Ok(())
Ok(bboxes)
}
}
// Return the first and last indexes i for which values[i] > 0
fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
for (i, &s) in values.iter().enumerate() {
if s == 0 {
continue;
}
min_i = usize::min(i, min_i);
max_i = usize::max(i, max_i);
}
if max_i < min_i {
None
} else {
Some((min_i, max_i))
}
}

View File

@ -45,9 +45,9 @@ impl Attention {
}
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let q = self.q_proj.forward(q)?;
let k = self.k_proj.forward(k)?;
let v = self.v_proj.forward(v)?;
let q = self.q_proj.forward(&q.contiguous()?)?;
let k = self.k_proj.forward(&k.contiguous()?)?;
let v = self.v_proj.forward(&v.contiguous()?)?;
let q = self.separate_heads(&q)?;
let k = self.separate_heads(&k)?;

View File

@ -46,7 +46,7 @@ pub fn report(
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
let mut bboxes: Vec<Vec<Bbox<()>>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.get(index)?)?;
@ -65,7 +65,7 @@ pub fn report(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints: vec![],
data: (),
};
bboxes[class_index].push(bbox)
}

View File

@ -64,7 +64,7 @@ pub fn report_detect(
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
@ -83,7 +83,7 @@ pub fn report_detect(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints: vec![],
data: vec![],
};
bboxes[class_index].push(bbox)
}
@ -176,7 +176,7 @@ pub fn report_pose(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints,
data: keypoints,
};
bboxes.push(bbox)
}
@ -204,7 +204,7 @@ pub fn report_pose(
image::Rgb([255, 0, 0]),
);
}
for kp in b.keypoints.iter() {
for kp in b.data.iter() {
if kp.mask < 0.6 {
continue;
}
@ -219,8 +219,8 @@ pub fn report_pose(
}
for &(idx1, idx2) in KP_CONNECTIONS.iter() {
let kp1 = &b.keypoints[idx1];
let kp2 = &b.keypoints[idx2];
let kp1 = &b.data[idx1];
let kp2 = &b.data[idx2];
if kp1.mask < 0.6 || kp2.mask < 0.6 {
continue;
}

View File

@ -1,12 +1,12 @@
/// A bounding box around an object.
#[derive(Debug, Clone)]
pub struct Bbox {
pub struct Bbox<D> {
pub xmin: f32,
pub ymin: f32,
pub xmax: f32,
pub ymax: f32,
pub confidence: f32,
pub keypoints: Vec<KeyPoint>,
pub data: D,
}
#[derive(Debug, Clone, Copy, PartialEq)]
@ -17,7 +17,7 @@ pub struct KeyPoint {
}
/// Intersection over union of two bounding boxes.
pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 {
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
let i_xmin = b1.xmin.max(b2.xmin);
@ -28,7 +28,7 @@ pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
i_area / (b1_area + b2_area - i_area)
}
pub fn non_maximum_suppression(bboxes: &mut [Vec<Bbox>], threshold: f32) {
pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());