mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Automatic mask generation (#779)
* A few more contiguous fixes for cuda. * Mask generation. * Generic bbox. * Generate all the masks.
This commit is contained in:
@ -8,6 +8,11 @@ use crate::model_prompt_encoder::PromptEncoder;
|
||||
const PROMPT_EMBED_DIM: usize = 256;
|
||||
pub const IMAGE_SIZE: usize = 1024;
|
||||
const VIT_PATCH_SIZE: usize = 16;
|
||||
const PRED_IOU_THRESH: f32 = 0.88;
|
||||
const STABILITY_SCORE_OFFSET: f32 = 1.0;
|
||||
const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
|
||||
const MODEL_MASK_THRESHOLD: f32 = 0.0;
|
||||
const CROP_NMS_THRESH: f32 = 0.7;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Sam {
|
||||
@ -129,7 +134,12 @@ impl Sam {
|
||||
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
|
||||
}
|
||||
|
||||
fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
|
||||
fn process_crop(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
cb: CropBox,
|
||||
point_grids: &[(f64, f64)],
|
||||
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
|
||||
// Crop the image and calculate embeddings.
|
||||
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
|
||||
let img = self.preprocess(&img)?.unsqueeze(0)?;
|
||||
@ -144,28 +154,86 @@ impl Sam {
|
||||
.iter()
|
||||
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut bboxes = Vec::new();
|
||||
for points in points.chunks(64) {
|
||||
// Run the model on this batch.
|
||||
let points_len = points.len();
|
||||
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
|
||||
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder
|
||||
.forward(Some((&in_points, &in_labels)), None, None)?;
|
||||
let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
/* multimask_output */ true,
|
||||
)?;
|
||||
let low_res_mask = low_res_mask.flatten(0, 1)?;
|
||||
let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
|
||||
let dev = low_res_mask.device();
|
||||
|
||||
println!("{cb:?} {iou_predictions}");
|
||||
for (i, iou) in iou_predictions.iter().enumerate() {
|
||||
// Filter by predicted IoU.
|
||||
if *iou < PRED_IOU_THRESH {
|
||||
continue;
|
||||
}
|
||||
let low_res_mask = low_res_mask.get(i)?;
|
||||
|
||||
// Calculate stability score.
|
||||
let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
|
||||
.broadcast_as(low_res_mask.shape())?;
|
||||
let intersections = low_res_mask
|
||||
.ge(&bound)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
|
||||
.broadcast_as(low_res_mask.shape())?;
|
||||
let unions = low_res_mask
|
||||
.ge(&bound)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
let stability_score = intersections / unions;
|
||||
if stability_score < STABILITY_SCORE_THRESHOLD {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Threshold masks and calculate boxes.
|
||||
let low_res_mask = low_res_mask
|
||||
.ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
|
||||
.to_dtype(DType::U32)?;
|
||||
let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
|
||||
let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
|
||||
let min_max_x = min_max_indexes(&low_res_mask_per_x);
|
||||
let min_max_y = min_max_indexes(&low_res_mask_per_y);
|
||||
if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
|
||||
let bbox = candle_examples::object_detection::Bbox {
|
||||
xmin: x0 as f32,
|
||||
ymin: y0 as f32,
|
||||
xmax: x1 as f32,
|
||||
ymax: y1 as f32,
|
||||
confidence: *iou,
|
||||
data: low_res_mask,
|
||||
};
|
||||
bboxes.push(bbox);
|
||||
}
|
||||
// TODO:
|
||||
// Filter boxes that touch crop boundaries
|
||||
// Compress to RLE.
|
||||
}
|
||||
}
|
||||
|
||||
let mut bboxes = vec![bboxes];
|
||||
// Remove duplicates within this crop.
|
||||
candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
|
||||
|
||||
// Return to the original image frame.
|
||||
Ok(())
|
||||
// TODO: Return to the original image frame.
|
||||
Ok(bboxes.remove(0))
|
||||
}
|
||||
|
||||
pub fn generate_masks(
|
||||
@ -175,7 +243,7 @@ impl Sam {
|
||||
crop_n_layer: usize,
|
||||
crop_overlap_ratio: f64,
|
||||
crop_n_points_downscale_factor: usize,
|
||||
) -> Result<()> {
|
||||
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
|
||||
let (_c, h, w) = img.dims3()?;
|
||||
let point_grids = build_all_layer_point_grids(
|
||||
points_per_side,
|
||||
@ -183,12 +251,31 @@ impl Sam {
|
||||
crop_n_points_downscale_factor,
|
||||
);
|
||||
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
|
||||
let mut bboxes = Vec::new();
|
||||
for crop_box in crop_boxes.into_iter() {
|
||||
let layer_idx = crop_box.layer_idx;
|
||||
self.process_crop(img, crop_box, &point_grids[layer_idx])?
|
||||
let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
|
||||
bboxes.extend(b)
|
||||
}
|
||||
// TODO: remove duplicates
|
||||
Ok(())
|
||||
Ok(bboxes)
|
||||
}
|
||||
}
|
||||
|
||||
// Return the first and last indexes i for which values[i] > 0
|
||||
fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
|
||||
let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
|
||||
for (i, &s) in values.iter().enumerate() {
|
||||
if s == 0 {
|
||||
continue;
|
||||
}
|
||||
min_i = usize::min(i, min_i);
|
||||
max_i = usize::max(i, max_i);
|
||||
}
|
||||
if max_i < min_i {
|
||||
None
|
||||
} else {
|
||||
Some((min_i, max_i))
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user