Soft Non-Maximum Suppression (#2400)

* Soft NMS with thresholds

* NMS Test

* Soft nms w/ boxes removed below threshold

* Soft nms test

* No longer removing bounding boxes to fit Soft-NMS focus

* Initialize confidence

* Added comments

* Refactored out updating based on IOU/sigma

* Score_threshold -> confidence_threshold for clarity

* Remove bboxes below confidence threshold

* Softnms basic functionality test

* Softnms confidence decay test

* Softnms confidence threshold test

* Softnms no overlapping bbox test

* Testing confidence after no overlap test

* Single bbox and no bbox tests

* Signify test completion

* Handling result of test functions

* Checking all pairs of bboxes instead of a forward pass

* Equal confidence overlap test

* Clarified tests for implementation

* No longer dropping boxes, just setting to 0.0

* Formatted w/ cargo
This commit is contained in:
Matthew O'Malley-Nichols
2024-08-09 22:57:52 -07:00
committed by GitHub
parent 6e6c1c99b0
commit 14db029494
2 changed files with 280 additions and 0 deletions

View File

@ -50,3 +50,61 @@ pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
bboxes_for_class.truncate(current_index);
}
}
// Updates confidences starting at highest and comparing subsequent boxes.
fn update_confidences<D>(
bboxes_for_class: &[Bbox<D>],
updated_confidences: &mut [f32],
iou_threshold: f32,
sigma: f32,
) {
let len = bboxes_for_class.len();
for current_index in 0..len {
let current_bbox = &bboxes_for_class[current_index];
for index in (current_index + 1)..len {
let iou_val = iou(current_bbox, &bboxes_for_class[index]);
if iou_val > iou_threshold {
// Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503
let decay = (-iou_val * iou_val / sigma).exp();
let updated_confidence = bboxes_for_class[index].confidence * decay;
updated_confidences[index] = updated_confidence;
}
}
}
}
// Sorts the bounding boxes by confidence and applies soft non-maximum suppression.
// This function is based on the algorithm described in https://arxiv.org/pdf/1704.04503
pub fn soft_non_maximum_suppression<D>(
bboxes: &mut [Vec<Bbox<D>>],
iou_threshold: Option<f32>,
confidence_threshold: Option<f32>,
sigma: Option<f32>,
) {
let iou_threshold = iou_threshold.unwrap_or(0.5);
let confidence_threshold = confidence_threshold.unwrap_or(0.1);
let sigma = sigma.unwrap_or(0.5);
for bboxes_for_class in bboxes.iter_mut() {
// Sort boxes by confidence in descending order
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut updated_confidences = bboxes_for_class
.iter()
.map(|bbox| bbox.confidence)
.collect::<Vec<_>>();
update_confidences(
bboxes_for_class,
&mut updated_confidences,
iou_threshold,
sigma,
);
// Update confidences, set to 0.0 if below threshold
for (i, &confidence) in updated_confidences.iter().enumerate() {
bboxes_for_class[i].confidence = if confidence < confidence_threshold {
0.0
} else {
confidence
};
}
}
}