From 8bc5fffa45a8431f393e1b077b528f9850a15378 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 25 Aug 2023 11:21:11 +0100
Subject: [PATCH] More support for pose estimation in yolo-v8. (#599)

* More support for pose estimation in yolo-v8.

* Support both object detection and pose-estimation in the yolo-v8 example.
---
Review notes (not part of the commit message):
- report_pose: the keypoint y coordinate is now scaled by h_ratio. It was
  scaled by w_ratio, while the bounding-box ymin/ymax in the same function
  use h_ratio. The post-image blob id on the yolo-v8 `index` line was
  presumably computed for the pre-fix content - regenerate if it matters.
- NOTE(review): report_pose still computes nclasses as pred_size - 4 and takes
  an argmax over pred[4 + i]. The pose model is loaded with num_classes = 1
  and (17, 3) keypoints, and the keypoints are read from pred[5..], so that
  argmax appears to range over keypoint values as well - confirm and fix
  separately.

 candle-examples/examples/yolo-v3/main.rs |   1 +
 candle-examples/examples/yolo-v8/main.rs | 176 +++++++++++++++++++++--
 candle-examples/src/object_detection.rs  |   3 +-
 3 files changed, 164 insertions(+), 16 deletions(-)

diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index 514e9a0c..5e388921 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -65,6 +65,7 @@ pub fn report(
                     xmax: pred[0] + pred[2] / 2.,
                     ymax: pred[1] + pred[3] / 2.,
                     confidence,
+                    keypoints: vec![],
                 };
                 bboxes[class_index].push(bbox)
             }
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index a93aa035..939475a2 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -16,7 +16,7 @@ use image::{DynamicImage, ImageBuffer};
 // https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py
 
 #[derive(Clone, Copy, PartialEq, Debug)]
-struct Multiples {
+pub struct Multiples {
     depth: f64,
     width: f64,
     ratio: f64,
@@ -727,7 +727,7 @@ pub fn draw_rect(
     }
 }
 
-pub fn report(
+pub fn report_detect(
     pred: &Tensor,
     img: DynamicImage,
     w: usize,
@@ -757,6 +757,7 @@ pub fn report(
                     xmax: pred[0] + pred[2] / 2.,
                     ymax: pred[1] + pred[3] / 2.,
                     confidence,
+                    keypoints: vec![],
                 };
                 bboxes[class_index].push(bbox)
             }
@@ -787,6 +788,85 @@ pub fn report(
     Ok(DynamicImage::ImageRgb8(img))
 }
 
+pub fn report_pose(
+    pred: &Tensor,
+    img: DynamicImage,
+    w: usize,
+    h: usize,
+    confidence_threshold: f32,
+    nms_threshold: f32,
+) -> Result<DynamicImage> {
+    let (pred_size, npreds) = pred.dims2()?;
+    let nclasses = pred_size - 4;
+    // The bounding boxes grouped by (maximum) class index.
+    let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
+    // Extract the bounding boxes for which confidence is above the threshold.
+    for index in 0..npreds {
+        let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
+        let confidence = pred[4];
+        if confidence > confidence_threshold {
+            let mut class_index = 0;
+            for i in 0..nclasses {
+                if pred[4 + i] > pred[4 + class_index] {
+                    class_index = i
+                }
+            }
+            if pred[class_index + 4] > 0. {
+                let keypoints = (0..17)
+                    .map(|i| (pred[5 + 3 * i], pred[3 * i + 6], pred[3 * i + 7]))
+                    .collect::<Vec<_>>();
+                let bbox = Bbox {
+                    xmin: pred[0] - pred[2] / 2.,
+                    ymin: pred[1] - pred[3] / 2.,
+                    xmax: pred[0] + pred[2] / 2.,
+                    ymax: pred[1] + pred[3] / 2.,
+                    confidence,
+                    keypoints,
+                };
+                bboxes[class_index].push(bbox)
+            }
+        }
+    }
+
+    non_maximum_suppression(&mut bboxes, nms_threshold);
+
+    // Annotate the original image and print boxes information.
+    let (initial_h, initial_w) = (img.height(), img.width());
+    let w_ratio = initial_w as f32 / w as f32;
+    let h_ratio = initial_h as f32 / h as f32;
+    let mut img = img.to_rgb8();
+    for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
+        for b in bboxes_for_class.iter() {
+            println!(
+                "{}: {:?}",
+                candle_examples::coco_classes::NAMES[class_index],
+                b
+            );
+            let xmin = ((b.xmin * w_ratio) as u32).clamp(0, initial_w - 1);
+            let ymin = ((b.ymin * h_ratio) as u32).clamp(0, initial_h - 1);
+            let xmax = ((b.xmax * w_ratio) as u32).clamp(0, initial_w - 1);
+            let ymax = ((b.ymax * h_ratio) as u32).clamp(0, initial_h - 1);
+            draw_rect(&mut img, xmin, xmax, ymin, ymax);
+            for (x, y, z) in b.keypoints.iter() {
+                if z < &0.6 {
+                    continue;
+                }
+                let x = x * w_ratio;
+                let y = y * h_ratio;
+                for dx in -2..3 {
+                    for dy in -2..3 {
+                        let x = ((x + dx as f32) as u32).clamp(0, initial_w - 1);
+                        let y = ((y + dy as f32) as u32).clamp(0, initial_h - 1);
+                        let pixel = img.get_pixel_mut(x, y);
+                        *pixel = image::Rgb([0, 255, 0]);
+                    }
+                }
+            }
+        }
+    }
+    Ok(DynamicImage::ImageRgb8(img))
+}
+
 #[derive(Clone, Copy, ValueEnum, Debug)]
 enum Which {
     N,
@@ -796,9 +876,15 @@ enum Which {
     X,
 }
 
+#[derive(Clone, Copy, ValueEnum, Debug)]
+enum YoloTask {
+    Detect,
+    Pose,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
-struct Args {
+pub struct Args {
     /// Model weights, in safetensors format.
     #[arg(long)]
     model: Option<String>,
@@ -816,6 +902,10 @@ struct Args {
     /// Threshold for non-maximum suppression.
     #[arg(long, default_value_t = 0.45)]
     nms_threshold: f32,
+
+    /// The task to be run.
+    #[arg(long, default_value = "detect")]
+    task: YoloTask,
 }
 
 impl Args {
@@ -825,23 +915,71 @@ impl Args {
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
                 let api = api.model("lmz/candle-yolo-v8".to_string());
-                let filename = match self.which {
-                    Which::N => "yolov8n.safetensors",
-                    Which::S => "yolov8s.safetensors",
-                    Which::M => "yolov8m.safetensors",
-                    Which::L => "yolov8l.safetensors",
-                    Which::X => "yolov8x.safetensors",
+                let size = match self.which {
+                    Which::N => "n",
+                    Which::S => "s",
+                    Which::M => "m",
+                    Which::L => "l",
+                    Which::X => "x",
                 };
-                api.get(filename)?
+                let task = match self.task {
+                    YoloTask::Pose => "-pose",
+                    YoloTask::Detect => "",
+                };
+                api.get(&format!("yolov8{size}{task}.safetensors"))?
             }
         };
         Ok(path)
     }
 }
 
-pub fn main() -> anyhow::Result<()> {
-    let args = Args::parse();
+pub trait Task: Module + Sized {
+    fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self>;
+    fn report(
+        pred: &Tensor,
+        img: DynamicImage,
+        w: usize,
+        h: usize,
+        confidence_threshold: f32,
+        nms_threshold: f32,
+    ) -> Result<DynamicImage>;
+}
 
+impl Task for YoloV8 {
+    fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self> {
+        YoloV8::load(vb, multiples, /* num_classes=*/ 80)
+    }
+
+    fn report(
+        pred: &Tensor,
+        img: DynamicImage,
+        w: usize,
+        h: usize,
+        confidence_threshold: f32,
+        nms_threshold: f32,
+    ) -> Result<DynamicImage> {
+        report_detect(pred, img, w, h, confidence_threshold, nms_threshold)
+    }
+}
+
+impl Task for YoloV8Pose {
+    fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self> {
+        YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))
+    }
+
+    fn report(
+        pred: &Tensor,
+        img: DynamicImage,
+        w: usize,
+        h: usize,
+        confidence_threshold: f32,
+        nms_threshold: f32,
+    ) -> Result<DynamicImage> {
+        report_pose(pred, img, w, h, confidence_threshold, nms_threshold)
+    }
+}
+
+pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
     // Create the model and load the weights from the file.
     let multiples = match args.which {
         Which::N => Multiples::n(),
@@ -854,8 +992,7 @@ pub fn main() -> anyhow::Result<()> {
     let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
-    let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
-    // let model = YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))?;
+    let model = T::load(vb, multiples)?;
     println!("model loaded");
     for image_name in args.images.iter() {
         println!("processing {image_name}");
@@ -892,7 +1029,7 @@ pub fn main() -> anyhow::Result<()> {
         let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
         let predictions = model.forward(&image_t)?.squeeze(0)?;
         println!("generated predictions {predictions:?}");
-        let image_t = report(
+        let image_t = T::report(
             &predictions,
             original_image,
             width,
@@ -907,3 +1044,12 @@ pub fn main() -> anyhow::Result<()> {
 
     Ok(())
 }
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    match args.task {
+        YoloTask::Detect => run::<YoloV8>(args)?,
+        YoloTask::Pose => run::<YoloV8Pose>(args)?,
+    }
+    Ok(())
+}
diff --git a/candle-examples/src/object_detection.rs b/candle-examples/src/object_detection.rs
index 78fa933d..7352a99f 100644
--- a/candle-examples/src/object_detection.rs
+++ b/candle-examples/src/object_detection.rs
@@ -1,11 +1,12 @@
 /// A bounding box around an object.
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone)]
 pub struct Bbox {
     pub xmin: f32,
     pub ymin: f32,
     pub xmax: f32,
     pub ymax: f32,
     pub confidence: f32,
+    pub keypoints: Vec<(f32, f32, f32)>,
 }
 
 /// Intersection over union of two bounding boxes.