Move the yolo shared bits to a common place. (#548)

* Move the yolo shared bits to a common place.

* Share more code.

* Configurable thresholds.
This commit is contained in:
Laurent Mazare
2023-08-22 13:03:07 +01:00
committed by GitHub
parent 20ce3e9f39
commit bb69d89e28
6 changed files with 113 additions and 181 deletions

View File

@ -1,82 +0,0 @@
pub const NAMES: [&str; 80] = [
"person",
"bicycle",
"car",
"motorbike",
"aeroplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"sofa",
"pottedplant",
"bed",
"diningtable",
"toilet",
"tvmonitor",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
];

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod coco_classes;
use candle_examples::object_detection::{non_maximum_suppression, Bbox};
mod darknet;
use anyhow::Result;
@ -13,30 +13,6 @@ use candle_nn::{Module, VarBuilder};
use clap::Parser;
use image::{DynamicImage, ImageBuffer};
const CONFIDENCE_THRESHOLD: f32 = 0.5;
const NMS_THRESHOLD: f32 = 0.4;
#[derive(Debug, Clone, Copy)]
struct Bbox {
xmin: f32,
ymin: f32,
xmax: f32,
ymax: f32,
confidence: f32,
}
// Intersection over union of two bounding boxes.
fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
let i_xmin = b1.xmin.max(b2.xmin);
let i_xmax = b1.xmax.min(b2.xmax);
let i_ymin = b1.ymin.max(b2.ymin);
let i_ymax = b1.ymax.min(b2.ymax);
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
i_area / (b1_area + b2_area - i_area)
}
// Assumes x1 <= x2 and y1 <= y2
pub fn draw_rect(
img: &mut ImageBuffer<image::Rgb<u8>, Vec<u8>>,
@ -59,7 +35,14 @@ pub fn draw_rect(
}
}
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
pub fn report(
pred: &Tensor,
img: DynamicImage,
w: usize,
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
// The bounding boxes grouped by (maximum) class index.
@ -68,7 +51,7 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.get(index)?)?;
let confidence = pred[4];
if confidence > CONFIDENCE_THRESHOLD {
if confidence > confidence_threshold {
let mut class_index = 0;
for i in 0..nclasses {
if pred[5 + i] > pred[5 + class_index] {
@ -87,26 +70,7 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
}
}
}
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut current_index = 0;
for index in 0..bboxes_for_class.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
if iou > NMS_THRESHOLD {
drop = true;
break;
}
}
if !drop {
bboxes_for_class.swap(current_index, index);
current_index += 1;
}
}
bboxes_for_class.truncate(current_index);
}
non_maximum_suppression(&mut bboxes, nms_threshold);
// Annotate the original image and print boxes information.
let (initial_h, initial_w) = (img.height(), img.width());
let w_ratio = initial_w as f32 / w as f32;
@ -114,7 +78,11 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
let mut img = img.to_rgb8();
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
println!("{}: {:?}", coco_classes::NAMES[class_index], b);
println!(
"{}: {:?}",
candle_examples::coco_classes::NAMES[class_index],
b
);
let xmin = ((b.xmin * w_ratio) as u32).clamp(0, initial_w - 1);
let ymin = ((b.ymin * h_ratio) as u32).clamp(0, initial_h - 1);
let xmax = ((b.xmax * w_ratio) as u32).clamp(0, initial_w - 1);
@ -136,6 +104,14 @@ struct Args {
config: Option<String>,
images: Vec<String>,
/// Threshold for the model confidence level.
#[arg(long, default_value_t = 0.5)]
confidence_threshold: f32,
/// Threshold for non-maximum suppression.
#[arg(long, default_value_t = 0.4)]
nms_threshold: f32,
}
impl Args {
@ -200,7 +176,14 @@ pub fn main() -> Result<()> {
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = model.forward(&image)?.squeeze(0)?;
println!("generated predictions {predictions:?}");
let image = report(&predictions, original_image, net_width, net_height)?;
let image = report(
&predictions,
original_image,
net_width,
net_height,
args.confidence_threshold,
args.nms_threshold,
)?;
image_name.set_extension("pp.jpg");
println!("writing {image_name:?}");
image.save(image_name)?