mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Soft Non-Maximum Suppression (#2400)
* Soft NMS with thresholds * NMS Test * Soft nms w/ boxes removed below threshold * Soft nms test * No longer removing bounding boxes to fit Soft-NMS focus * Initialize confidence * Added comments * Refactored out updating based on IOU/sigma * Score_threshold -> confidence_threshold for clarity * Remove bboxes below confidence threshold * Softnms basic functionality test * Softnms confidence decay test * Softnms confidence threshold test * Softnms no overlapping bbox test * Testing confidence after no overlap test * Single bbox and no bbox tests * Signify test completion * Handling result of test functions * Checking all pairs of bboxes instead of a forward pass * Equal confidence overlap test * Clarified tests for implementation * No longer dropping boxes, just setting to 0.0 * Formatted w/ cargo
This commit is contained in:

committed by
GitHub

parent
6e6c1c99b0
commit
14db029494
@ -50,3 +50,61 @@ pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
|
||||
bboxes_for_class.truncate(current_index);
|
||||
}
|
||||
}
|
||||
|
||||
// Updates confidences starting at highest and comparing subsequent boxes.
|
||||
fn update_confidences<D>(
|
||||
bboxes_for_class: &[Bbox<D>],
|
||||
updated_confidences: &mut [f32],
|
||||
iou_threshold: f32,
|
||||
sigma: f32,
|
||||
) {
|
||||
let len = bboxes_for_class.len();
|
||||
for current_index in 0..len {
|
||||
let current_bbox = &bboxes_for_class[current_index];
|
||||
for index in (current_index + 1)..len {
|
||||
let iou_val = iou(current_bbox, &bboxes_for_class[index]);
|
||||
if iou_val > iou_threshold {
|
||||
// Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503
|
||||
let decay = (-iou_val * iou_val / sigma).exp();
|
||||
let updated_confidence = bboxes_for_class[index].confidence * decay;
|
||||
updated_confidences[index] = updated_confidence;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sorts the bounding boxes by confidence and applies soft non-maximum suppression.
|
||||
// This function is based on the algorithm described in https://arxiv.org/pdf/1704.04503
|
||||
pub fn soft_non_maximum_suppression<D>(
|
||||
bboxes: &mut [Vec<Bbox<D>>],
|
||||
iou_threshold: Option<f32>,
|
||||
confidence_threshold: Option<f32>,
|
||||
sigma: Option<f32>,
|
||||
) {
|
||||
let iou_threshold = iou_threshold.unwrap_or(0.5);
|
||||
let confidence_threshold = confidence_threshold.unwrap_or(0.1);
|
||||
let sigma = sigma.unwrap_or(0.5);
|
||||
|
||||
for bboxes_for_class in bboxes.iter_mut() {
|
||||
// Sort boxes by confidence in descending order
|
||||
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
|
||||
let mut updated_confidences = bboxes_for_class
|
||||
.iter()
|
||||
.map(|bbox| bbox.confidence)
|
||||
.collect::<Vec<_>>();
|
||||
update_confidences(
|
||||
bboxes_for_class,
|
||||
&mut updated_confidences,
|
||||
iou_threshold,
|
||||
sigma,
|
||||
);
|
||||
// Update confidences, set to 0.0 if below threshold
|
||||
for (i, &confidence) in updated_confidences.iter().enumerate() {
|
||||
bboxes_for_class[i].confidence = if confidence < confidence_threshold {
|
||||
0.0
|
||||
} else {
|
||||
confidence
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user