diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index ab047304..d5c5ac1c 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -154,89 +154,82 @@ pub fn report_pose( nms_threshold: f32, ) -> Result { let (pred_size, npreds) = pred.dims2()?; - let nclasses = pred_size - 4; - // The bounding boxes grouped by (maximum) class index. - let mut bboxes: Vec> = (0..nclasses).map(|_| vec![]).collect(); + if pred_size != 17 * 3 + 4 + 1 { + candle::bail!("unexpected pred-size {pred_size}"); + } + let mut bboxes = vec![]; // Extract the bounding boxes for which confidence is above the threshold. for index in 0..npreds { let pred = Vec::::try_from(pred.i((.., index))?)?; let confidence = pred[4]; if confidence > confidence_threshold { - let mut class_index = 0; - for i in 0..nclasses { - if pred[4 + i] > pred[4 + class_index] { - class_index = i - } - } - if pred[class_index + 4] > 0. { - let keypoints = (0..17) - .map(|i| KeyPoint { - x: pred[5 + 3 * i], - y: pred[3 * i + 6], - mask: pred[3 * i + 7], - }) - .collect::>(); - let bbox = Bbox { - xmin: pred[0] - pred[2] / 2., - ymin: pred[1] - pred[3] / 2., - xmax: pred[0] + pred[2] / 2., - ymax: pred[1] + pred[3] / 2., - confidence, - keypoints, - }; - bboxes[class_index].push(bbox) - } + let keypoints = (0..17) + .map(|i| KeyPoint { + x: pred[3 * i + 5], + y: pred[3 * i + 6], + mask: pred[3 * i + 7], + }) + .collect::>(); + let bbox = Bbox { + xmin: pred[0] - pred[2] / 2., + ymin: pred[1] - pred[3] / 2., + xmax: pred[0] + pred[2] / 2., + ymax: pred[1] + pred[3] / 2., + confidence, + keypoints, + }; + bboxes.push(bbox) } } + let mut bboxes = vec![bboxes]; non_maximum_suppression(&mut bboxes, nms_threshold); + let bboxes = &bboxes[0]; // Annotate the original image and print boxes information. let (initial_h, initial_w) = (img.height(), img.width()); let w_ratio = initial_w as f32 / w as f32; let h_ratio = initial_h as f32 / h as f32; let mut img = img.to_rgb8(); - for bboxes_for_class in bboxes.iter() { - for b in bboxes_for_class.iter() { - println!("{b:?}"); - let xmin = (b.xmin * w_ratio) as i32; - let ymin = (b.ymin * h_ratio) as i32; - let dx = (b.xmax - b.xmin) * w_ratio; - let dy = (b.ymax - b.ymin) * h_ratio; - if dx >= 0. && dy >= 0. { - imageproc::drawing::draw_hollow_rect_mut( - &mut img, - imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, dy as u32), - image::Rgb([255, 0, 0]), - ); - } - for kp in b.keypoints.iter() { - if kp.mask < 0.6 { - continue; - } - let x = (kp.x * w_ratio) as i32; - let y = (kp.y * h_ratio) as i32; - imageproc::drawing::draw_filled_circle_mut( - &mut img, - (x, y), - 2, - image::Rgb([0, 255, 0]), - ); + for b in bboxes.iter() { + println!("{b:?}"); + let xmin = (b.xmin * w_ratio) as i32; + let ymin = (b.ymin * h_ratio) as i32; + let dx = (b.xmax - b.xmin) * w_ratio; + let dy = (b.ymax - b.ymin) * h_ratio; + if dx >= 0. && dy >= 0. { + imageproc::drawing::draw_hollow_rect_mut( + &mut img, + imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, dy as u32), + image::Rgb([255, 0, 0]), + ); + } + for kp in b.keypoints.iter() { + if kp.mask < 0.6 { + continue; } + let x = (kp.x * w_ratio) as i32; + let y = (kp.y * h_ratio) as i32; + imageproc::drawing::draw_filled_circle_mut( + &mut img, + (x, y), + 2, + image::Rgb([0, 255, 0]), + ); + } - for &(idx1, idx2) in KP_CONNECTIONS.iter() { - let kp1 = &b.keypoints[idx1]; - let kp2 = &b.keypoints[idx2]; - if kp1.mask < 0.6 || kp2.mask < 0.6 { - continue; - } - imageproc::drawing::draw_line_segment_mut( - &mut img, - (kp1.x * w_ratio, kp1.y * h_ratio), - (kp2.x * w_ratio, kp2.y * h_ratio), - image::Rgb([255, 255, 0]), - ); + for &(idx1, idx2) in KP_CONNECTIONS.iter() { + let kp1 = &b.keypoints[idx1]; + let kp2 = &b.keypoints[idx2]; + if kp1.mask < 0.6 || kp2.mask < 0.6 { + continue; } + imageproc::drawing::draw_line_segment_mut( + &mut img, + (kp1.x * w_ratio, kp1.y * h_ratio), + (kp2.x * w_ratio, kp2.y * h_ratio), + image::Rgb([255, 255, 0]), + ); } } Ok(DynamicImage::ImageRgb8(img))