mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Automatic mask generation (#779)
* A few more contiguous fixes for cuda. * Mask generation. * Generic bbox. * Generate all the masks.
This commit is contained in:
@ -188,13 +188,25 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
if args.generate_masks {
|
||||
// Default options similar to the Python version.
|
||||
sam.generate_masks(
|
||||
let bboxes = sam.generate_masks(
|
||||
&image,
|
||||
/* points_per_side */ 32,
|
||||
/* crop_n_layer */ 0,
|
||||
/* crop_overlap_ratio */ 512. / 1500.,
|
||||
/* crop_n_points_downscale_factor */ 1,
|
||||
)?
|
||||
)?;
|
||||
for (idx, bbox) in bboxes.iter().enumerate() {
|
||||
println!("{bbox:?}");
|
||||
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
|
||||
let (h, w) = mask.dims2()?;
|
||||
let mask = mask.broadcast_as((3, h, w))?;
|
||||
candle_examples::save_image_resize(
|
||||
&mask,
|
||||
format!("sam_mask{idx}.png"),
|
||||
initial_h,
|
||||
initial_w,
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
let point = Some((args.point_x, args.point_y));
|
||||
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
||||
|
@ -219,7 +219,7 @@ impl MaskDecoder {
|
||||
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
|
||||
hyper_in_list.push(h)
|
||||
}
|
||||
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
|
||||
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
|
||||
let (b, c, h, w) = upscaled_embedding.dims4()?;
|
||||
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
|
||||
let masks = masks.reshape((b, (), h, w))?;
|
||||
|
@ -8,6 +8,11 @@ use crate::model_prompt_encoder::PromptEncoder;
|
||||
const PROMPT_EMBED_DIM: usize = 256;
|
||||
pub const IMAGE_SIZE: usize = 1024;
|
||||
const VIT_PATCH_SIZE: usize = 16;
|
||||
const PRED_IOU_THRESH: f32 = 0.88;
|
||||
const STABILITY_SCORE_OFFSET: f32 = 1.0;
|
||||
const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
|
||||
const MODEL_MASK_THRESHOLD: f32 = 0.0;
|
||||
const CROP_NMS_THRESH: f32 = 0.7;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Sam {
|
||||
@ -129,7 +134,12 @@ impl Sam {
|
||||
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
|
||||
}
|
||||
|
||||
fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
|
||||
fn process_crop(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
cb: CropBox,
|
||||
point_grids: &[(f64, f64)],
|
||||
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
|
||||
// Crop the image and calculate embeddings.
|
||||
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
|
||||
let img = self.preprocess(&img)?.unsqueeze(0)?;
|
||||
@ -144,28 +154,86 @@ impl Sam {
|
||||
.iter()
|
||||
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut bboxes = Vec::new();
|
||||
for points in points.chunks(64) {
|
||||
// Run the model on this batch.
|
||||
let points_len = points.len();
|
||||
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
|
||||
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder
|
||||
.forward(Some((&in_points, &in_labels)), None, None)?;
|
||||
let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
/* multimask_output */ true,
|
||||
)?;
|
||||
let low_res_mask = low_res_mask.flatten(0, 1)?;
|
||||
let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
|
||||
let dev = low_res_mask.device();
|
||||
|
||||
println!("{cb:?} {iou_predictions}");
|
||||
for (i, iou) in iou_predictions.iter().enumerate() {
|
||||
// Filter by predicted IoU.
|
||||
if *iou < PRED_IOU_THRESH {
|
||||
continue;
|
||||
}
|
||||
let low_res_mask = low_res_mask.get(i)?;
|
||||
|
||||
// Calculate stability score.
|
||||
let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
|
||||
.broadcast_as(low_res_mask.shape())?;
|
||||
let intersections = low_res_mask
|
||||
.ge(&bound)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
|
||||
.broadcast_as(low_res_mask.shape())?;
|
||||
let unions = low_res_mask
|
||||
.ge(&bound)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
let stability_score = intersections / unions;
|
||||
if stability_score < STABILITY_SCORE_THRESHOLD {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Threshold masks and calculate boxes.
|
||||
let low_res_mask = low_res_mask
|
||||
.ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
|
||||
.to_dtype(DType::U32)?;
|
||||
let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
|
||||
let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
|
||||
let min_max_x = min_max_indexes(&low_res_mask_per_x);
|
||||
let min_max_y = min_max_indexes(&low_res_mask_per_y);
|
||||
if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
|
||||
let bbox = candle_examples::object_detection::Bbox {
|
||||
xmin: x0 as f32,
|
||||
ymin: y0 as f32,
|
||||
xmax: x1 as f32,
|
||||
ymax: y1 as f32,
|
||||
confidence: *iou,
|
||||
data: low_res_mask,
|
||||
};
|
||||
bboxes.push(bbox);
|
||||
}
|
||||
// TODO:
|
||||
// Filter boxes that touch crop boundaries
|
||||
// Compress to RLE.
|
||||
}
|
||||
}
|
||||
|
||||
let mut bboxes = vec![bboxes];
|
||||
// Remove duplicates within this crop.
|
||||
candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
|
||||
|
||||
// Return to the original image frame.
|
||||
Ok(())
|
||||
// TODO: Return to the original image frame.
|
||||
Ok(bboxes.remove(0))
|
||||
}
|
||||
|
||||
pub fn generate_masks(
|
||||
@ -175,7 +243,7 @@ impl Sam {
|
||||
crop_n_layer: usize,
|
||||
crop_overlap_ratio: f64,
|
||||
crop_n_points_downscale_factor: usize,
|
||||
) -> Result<()> {
|
||||
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
|
||||
let (_c, h, w) = img.dims3()?;
|
||||
let point_grids = build_all_layer_point_grids(
|
||||
points_per_side,
|
||||
@ -183,12 +251,31 @@ impl Sam {
|
||||
crop_n_points_downscale_factor,
|
||||
);
|
||||
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
|
||||
let mut bboxes = Vec::new();
|
||||
for crop_box in crop_boxes.into_iter() {
|
||||
let layer_idx = crop_box.layer_idx;
|
||||
self.process_crop(img, crop_box, &point_grids[layer_idx])?
|
||||
let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
|
||||
bboxes.extend(b)
|
||||
}
|
||||
// TODO: remove duplicates
|
||||
Ok(())
|
||||
Ok(bboxes)
|
||||
}
|
||||
}
|
||||
|
||||
// Return the first and last indexes i for which values[i] > 0
|
||||
fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
|
||||
let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
|
||||
for (i, &s) in values.iter().enumerate() {
|
||||
if s == 0 {
|
||||
continue;
|
||||
}
|
||||
min_i = usize::min(i, min_i);
|
||||
max_i = usize::max(i, max_i);
|
||||
}
|
||||
if max_i < min_i {
|
||||
None
|
||||
} else {
|
||||
Some((min_i, max_i))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,9 +45,9 @@ impl Attention {
|
||||
}
|
||||
|
||||
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||
let q = self.q_proj.forward(q)?;
|
||||
let k = self.k_proj.forward(k)?;
|
||||
let v = self.v_proj.forward(v)?;
|
||||
let q = self.q_proj.forward(&q.contiguous()?)?;
|
||||
let k = self.k_proj.forward(&k.contiguous()?)?;
|
||||
let v = self.v_proj.forward(&v.contiguous()?)?;
|
||||
|
||||
let q = self.separate_heads(&q)?;
|
||||
let k = self.separate_heads(&k)?;
|
||||
|
Reference in New Issue
Block a user