Move the yolo shared bits to a common place. (#548)

* Move the yolo shared bits to a common place.

* Share more code.

* Configurable thresholds.
This commit is contained in:
Laurent Mazare
2023-08-22 13:03:07 +01:00
committed by GitHub
parent 20ce3e9f39
commit bb69d89e28
6 changed files with 113 additions and 181 deletions

View File

@ -0,0 +1,82 @@
pub const NAMES: [&str; 80] = [
"person",
"bicycle",
"car",
"motorbike",
"aeroplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"sofa",
"pottedplant",
"bed",
"diningtable",
"toilet",
"tvmonitor",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
];

View File

@ -1,3 +1,6 @@
pub mod coco_classes;
pub mod object_detection;
use candle::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> {

View File

@ -0,0 +1,44 @@
/// A bounding box around an object.
#[derive(Debug, Clone, Copy)]
pub struct Bbox {
pub xmin: f32,
pub ymin: f32,
pub xmax: f32,
pub ymax: f32,
pub confidence: f32,
}
/// Intersection over union of two bounding boxes.
pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
let i_xmin = b1.xmin.max(b2.xmin);
let i_xmax = b1.xmax.min(b2.xmax);
let i_ymin = b1.ymin.max(b2.ymin);
let i_ymax = b1.ymax.min(b2.ymax);
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
i_area / (b1_area + b2_area - i_area)
}
pub fn non_maximum_suppression(bboxes: &mut [Vec<Bbox>], threshold: f32) {
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut current_index = 0;
for index in 0..bboxes_for_class.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
if iou > threshold {
drop = true;
break;
}
}
if !drop {
bboxes_for_class.swap(current_index, index);
current_index += 1;
}
}
bboxes_for_class.truncate(current_index);
}
}