mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Move the yolo shared bits to a common place. (#548)
* Move the yolo shared bits to a common place. * Share more code. * Configurable thresholds.
This commit is contained in:
@ -4,7 +4,7 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
mod coco_classes;
|
||||
use candle_examples::object_detection::{non_maximum_suppression, Bbox};
|
||||
mod darknet;
|
||||
|
||||
use anyhow::Result;
|
||||
@ -13,30 +13,6 @@ use candle_nn::{Module, VarBuilder};
|
||||
use clap::Parser;
|
||||
use image::{DynamicImage, ImageBuffer};
|
||||
|
||||
const CONFIDENCE_THRESHOLD: f32 = 0.5;
|
||||
const NMS_THRESHOLD: f32 = 0.4;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct Bbox {
|
||||
xmin: f32,
|
||||
ymin: f32,
|
||||
xmax: f32,
|
||||
ymax: f32,
|
||||
confidence: f32,
|
||||
}
|
||||
|
||||
// Intersection over union of two bounding boxes.
|
||||
fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
|
||||
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
|
||||
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
|
||||
let i_xmin = b1.xmin.max(b2.xmin);
|
||||
let i_xmax = b1.xmax.min(b2.xmax);
|
||||
let i_ymin = b1.ymin.max(b2.ymin);
|
||||
let i_ymax = b1.ymax.min(b2.ymax);
|
||||
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
|
||||
i_area / (b1_area + b2_area - i_area)
|
||||
}
|
||||
|
||||
// Assumes x1 <= x2 and y1 <= y2
|
||||
pub fn draw_rect(
|
||||
img: &mut ImageBuffer<image::Rgb<u8>, Vec<u8>>,
|
||||
@ -59,7 +35,14 @@ pub fn draw_rect(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
|
||||
pub fn report(
|
||||
pred: &Tensor,
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
h: usize,
|
||||
confidence_threshold: f32,
|
||||
nms_threshold: f32,
|
||||
) -> Result<DynamicImage> {
|
||||
let (npreds, pred_size) = pred.dims2()?;
|
||||
let nclasses = pred_size - 5;
|
||||
// The bounding boxes grouped by (maximum) class index.
|
||||
@ -68,7 +51,7 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
|
||||
for index in 0..npreds {
|
||||
let pred = Vec::<f32>::try_from(pred.get(index)?)?;
|
||||
let confidence = pred[4];
|
||||
if confidence > CONFIDENCE_THRESHOLD {
|
||||
if confidence > confidence_threshold {
|
||||
let mut class_index = 0;
|
||||
for i in 0..nclasses {
|
||||
if pred[5 + i] > pred[5 + class_index] {
|
||||
@ -87,26 +70,7 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
|
||||
}
|
||||
}
|
||||
}
|
||||
// Perform non-maximum suppression.
|
||||
for bboxes_for_class in bboxes.iter_mut() {
|
||||
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
|
||||
let mut current_index = 0;
|
||||
for index in 0..bboxes_for_class.len() {
|
||||
let mut drop = false;
|
||||
for prev_index in 0..current_index {
|
||||
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
|
||||
if iou > NMS_THRESHOLD {
|
||||
drop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !drop {
|
||||
bboxes_for_class.swap(current_index, index);
|
||||
current_index += 1;
|
||||
}
|
||||
}
|
||||
bboxes_for_class.truncate(current_index);
|
||||
}
|
||||
non_maximum_suppression(&mut bboxes, nms_threshold);
|
||||
// Annotate the original image and print boxes information.
|
||||
let (initial_h, initial_w) = (img.height(), img.width());
|
||||
let w_ratio = initial_w as f32 / w as f32;
|
||||
@ -114,7 +78,11 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
|
||||
let mut img = img.to_rgb8();
|
||||
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
|
||||
for b in bboxes_for_class.iter() {
|
||||
println!("{}: {:?}", coco_classes::NAMES[class_index], b);
|
||||
println!(
|
||||
"{}: {:?}",
|
||||
candle_examples::coco_classes::NAMES[class_index],
|
||||
b
|
||||
);
|
||||
let xmin = ((b.xmin * w_ratio) as u32).clamp(0, initial_w - 1);
|
||||
let ymin = ((b.ymin * h_ratio) as u32).clamp(0, initial_h - 1);
|
||||
let xmax = ((b.xmax * w_ratio) as u32).clamp(0, initial_w - 1);
|
||||
@ -136,6 +104,14 @@ struct Args {
|
||||
config: Option<String>,
|
||||
|
||||
images: Vec<String>,
|
||||
|
||||
/// Threshold for the model confidence level.
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
confidence_threshold: f32,
|
||||
|
||||
/// Threshold for non-maximum suppression.
|
||||
#[arg(long, default_value_t = 0.4)]
|
||||
nms_threshold: f32,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -200,7 +176,14 @@ pub fn main() -> Result<()> {
|
||||
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||
let predictions = model.forward(&image)?.squeeze(0)?;
|
||||
println!("generated predictions {predictions:?}");
|
||||
let image = report(&predictions, original_image, net_width, net_height)?;
|
||||
let image = report(
|
||||
&predictions,
|
||||
original_image,
|
||||
net_width,
|
||||
net_height,
|
||||
args.confidence_threshold,
|
||||
args.nms_threshold,
|
||||
)?;
|
||||
image_name.set_extension("pp.jpg");
|
||||
println!("writing {image_name:?}");
|
||||
image.save(image_name)?
|
||||
|
@ -1,82 +0,0 @@
|
||||
pub const NAMES: [&str; 80] = [
|
||||
"person",
|
||||
"bicycle",
|
||||
"car",
|
||||
"motorbike",
|
||||
"aeroplane",
|
||||
"bus",
|
||||
"train",
|
||||
"truck",
|
||||
"boat",
|
||||
"traffic light",
|
||||
"fire hydrant",
|
||||
"stop sign",
|
||||
"parking meter",
|
||||
"bench",
|
||||
"bird",
|
||||
"cat",
|
||||
"dog",
|
||||
"horse",
|
||||
"sheep",
|
||||
"cow",
|
||||
"elephant",
|
||||
"bear",
|
||||
"zebra",
|
||||
"giraffe",
|
||||
"backpack",
|
||||
"umbrella",
|
||||
"handbag",
|
||||
"tie",
|
||||
"suitcase",
|
||||
"frisbee",
|
||||
"skis",
|
||||
"snowboard",
|
||||
"sports ball",
|
||||
"kite",
|
||||
"baseball bat",
|
||||
"baseball glove",
|
||||
"skateboard",
|
||||
"surfboard",
|
||||
"tennis racket",
|
||||
"bottle",
|
||||
"wine glass",
|
||||
"cup",
|
||||
"fork",
|
||||
"knife",
|
||||
"spoon",
|
||||
"bowl",
|
||||
"banana",
|
||||
"apple",
|
||||
"sandwich",
|
||||
"orange",
|
||||
"broccoli",
|
||||
"carrot",
|
||||
"hot dog",
|
||||
"pizza",
|
||||
"donut",
|
||||
"cake",
|
||||
"chair",
|
||||
"sofa",
|
||||
"pottedplant",
|
||||
"bed",
|
||||
"diningtable",
|
||||
"toilet",
|
||||
"tvmonitor",
|
||||
"laptop",
|
||||
"mouse",
|
||||
"remote",
|
||||
"keyboard",
|
||||
"cell phone",
|
||||
"microwave",
|
||||
"oven",
|
||||
"toaster",
|
||||
"sink",
|
||||
"refrigerator",
|
||||
"book",
|
||||
"clock",
|
||||
"vase",
|
||||
"scissors",
|
||||
"teddy bear",
|
||||
"hair drier",
|
||||
"toothbrush",
|
||||
];
|
@ -4,18 +4,14 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
mod coco_classes;
|
||||
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_examples::object_detection::{non_maximum_suppression, Bbox};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
|
||||
};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use image::{DynamicImage, ImageBuffer};
|
||||
|
||||
const CONFIDENCE_THRESHOLD: f32 = 0.5;
|
||||
const NMS_THRESHOLD: f32 = 0.4;
|
||||
|
||||
// Model architecture from https://github.com/ultralytics/ultralytics/issues/189
|
||||
// https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py
|
||||
|
||||
@ -606,27 +602,6 @@ impl Module for YoloV8 {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct Bbox {
|
||||
xmin: f32,
|
||||
ymin: f32,
|
||||
xmax: f32,
|
||||
ymax: f32,
|
||||
confidence: f32,
|
||||
}
|
||||
|
||||
// Intersection over union of two bounding boxes.
|
||||
fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
|
||||
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
|
||||
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
|
||||
let i_xmin = b1.xmin.max(b2.xmin);
|
||||
let i_xmax = b1.xmax.min(b2.xmax);
|
||||
let i_ymin = b1.ymin.max(b2.ymin);
|
||||
let i_ymax = b1.ymax.min(b2.ymax);
|
||||
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
|
||||
i_area / (b1_area + b2_area - i_area)
|
||||
}
|
||||
|
||||
// Assumes x1 <= x2 and y1 <= y2
|
||||
pub fn draw_rect(
|
||||
img: &mut ImageBuffer<image::Rgb<u8>, Vec<u8>>,
|
||||
@ -649,7 +624,14 @@ pub fn draw_rect(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
|
||||
pub fn report(
|
||||
pred: &Tensor,
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
h: usize,
|
||||
confidence_threshold: f32,
|
||||
nms_threshold: f32,
|
||||
) -> Result<DynamicImage> {
|
||||
let (pred_size, npreds) = pred.dims2()?;
|
||||
let nclasses = pred_size - 4;
|
||||
// The bounding boxes grouped by (maximum) class index.
|
||||
@ -658,7 +640,7 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
|
||||
for index in 0..npreds {
|
||||
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
||||
let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap();
|
||||
if confidence > CONFIDENCE_THRESHOLD {
|
||||
if confidence > confidence_threshold {
|
||||
let mut class_index = 0;
|
||||
for i in 0..nclasses {
|
||||
if pred[4 + i] > pred[4 + class_index] {
|
||||
@ -677,26 +659,9 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
|
||||
}
|
||||
}
|
||||
}
|
||||
// Perform non-maximum suppression.
|
||||
for bboxes_for_class in bboxes.iter_mut() {
|
||||
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
|
||||
let mut current_index = 0;
|
||||
for index in 0..bboxes_for_class.len() {
|
||||
let mut drop = false;
|
||||
for prev_index in 0..current_index {
|
||||
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
|
||||
if iou > NMS_THRESHOLD {
|
||||
drop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !drop {
|
||||
bboxes_for_class.swap(current_index, index);
|
||||
current_index += 1;
|
||||
}
|
||||
}
|
||||
bboxes_for_class.truncate(current_index);
|
||||
}
|
||||
|
||||
non_maximum_suppression(&mut bboxes, nms_threshold);
|
||||
|
||||
// Annotate the original image and print boxes information.
|
||||
let (initial_h, initial_w) = (img.height(), img.width());
|
||||
let w_ratio = initial_w as f32 / w as f32;
|
||||
@ -704,7 +669,11 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
|
||||
let mut img = img.to_rgb8();
|
||||
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
|
||||
for b in bboxes_for_class.iter() {
|
||||
println!("{}: {:?}", coco_classes::NAMES[class_index], b);
|
||||
println!(
|
||||
"{}: {:?}",
|
||||
candle_examples::coco_classes::NAMES[class_index],
|
||||
b
|
||||
);
|
||||
let xmin = ((b.xmin * w_ratio) as u32).clamp(0, initial_w - 1);
|
||||
let ymin = ((b.ymin * h_ratio) as u32).clamp(0, initial_h - 1);
|
||||
let xmax = ((b.xmax * w_ratio) as u32).clamp(0, initial_w - 1);
|
||||
@ -736,6 +705,14 @@ struct Args {
|
||||
which: Which,
|
||||
|
||||
images: Vec<String>,
|
||||
|
||||
/// Threshold for the model confidence level.
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
confidence_threshold: f32,
|
||||
|
||||
/// Threshold for non-maximum suppression.
|
||||
#[arg(long, default_value_t = 0.4)]
|
||||
nms_threshold: f32,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -792,7 +769,14 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||
let predictions = model.forward(&image)?.squeeze(0)?;
|
||||
println!("generated predictions {predictions:?}");
|
||||
let image = report(&predictions, original_image, 640, 640)?;
|
||||
let image = report(
|
||||
&predictions,
|
||||
original_image,
|
||||
640,
|
||||
640,
|
||||
args.confidence_threshold,
|
||||
args.nms_threshold,
|
||||
)?;
|
||||
image_name.set_extension("pp.jpg");
|
||||
println!("writing {image_name:?}");
|
||||
image.save(image_name)?
|
||||
|
@ -1,3 +1,6 @@
|
||||
pub mod coco_classes;
|
||||
pub mod object_detection;
|
||||
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
pub fn device(cpu: bool) -> Result<Device> {
|
||||
|
44
candle-examples/src/object_detection.rs
Normal file
44
candle-examples/src/object_detection.rs
Normal file
@ -0,0 +1,44 @@
|
||||
/// A bounding box around an object.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Bbox {
|
||||
pub xmin: f32,
|
||||
pub ymin: f32,
|
||||
pub xmax: f32,
|
||||
pub ymax: f32,
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
/// Intersection over union of two bounding boxes.
|
||||
pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
|
||||
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
|
||||
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
|
||||
let i_xmin = b1.xmin.max(b2.xmin);
|
||||
let i_xmax = b1.xmax.min(b2.xmax);
|
||||
let i_ymin = b1.ymin.max(b2.ymin);
|
||||
let i_ymax = b1.ymax.min(b2.ymax);
|
||||
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
|
||||
i_area / (b1_area + b2_area - i_area)
|
||||
}
|
||||
|
||||
pub fn non_maximum_suppression(bboxes: &mut [Vec<Bbox>], threshold: f32) {
|
||||
// Perform non-maximum suppression.
|
||||
for bboxes_for_class in bboxes.iter_mut() {
|
||||
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
|
||||
let mut current_index = 0;
|
||||
for index in 0..bboxes_for_class.len() {
|
||||
let mut drop = false;
|
||||
for prev_index in 0..current_index {
|
||||
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
|
||||
if iou > threshold {
|
||||
drop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !drop {
|
||||
bboxes_for_class.swap(current_index, index);
|
||||
current_index += 1;
|
||||
}
|
||||
}
|
||||
bboxes_for_class.truncate(current_index);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user