Cleanup the pose reporting code. (#605)

This commit is contained in:
Laurent Mazare
2023-08-25 16:49:21 +01:00
committed by GitHub
parent 9c8d6dbc2a
commit ca6c050b04

View File

@ -154,89 +154,82 @@ pub fn report_pose(
nms_threshold: f32, nms_threshold: f32,
) -> Result<DynamicImage> { ) -> Result<DynamicImage> {
let (pred_size, npreds) = pred.dims2()?; let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4; if pred_size != 17 * 3 + 4 + 1 {
// The bounding boxes grouped by (maximum) class index. candle::bail!("unexpected pred-size {pred_size}");
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect(); }
let mut bboxes = vec![];
// Extract the bounding boxes for which confidence is above the threshold. // Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds { for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?; let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
let confidence = pred[4]; let confidence = pred[4];
if confidence > confidence_threshold { if confidence > confidence_threshold {
let mut class_index = 0; let keypoints = (0..17)
for i in 0..nclasses { .map(|i| KeyPoint {
if pred[4 + i] > pred[4 + class_index] { x: pred[3 * i + 5],
class_index = i y: pred[3 * i + 6],
} mask: pred[3 * i + 7],
} })
if pred[class_index + 4] > 0. { .collect::<Vec<_>>();
let keypoints = (0..17) let bbox = Bbox {
.map(|i| KeyPoint { xmin: pred[0] - pred[2] / 2.,
x: pred[5 + 3 * i], ymin: pred[1] - pred[3] / 2.,
y: pred[3 * i + 6], xmax: pred[0] + pred[2] / 2.,
mask: pred[3 * i + 7], ymax: pred[1] + pred[3] / 2.,
}) confidence,
.collect::<Vec<_>>(); keypoints,
let bbox = Bbox { };
xmin: pred[0] - pred[2] / 2., bboxes.push(bbox)
ymin: pred[1] - pred[3] / 2.,
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints,
};
bboxes[class_index].push(bbox)
}
} }
} }
let mut bboxes = vec![bboxes];
non_maximum_suppression(&mut bboxes, nms_threshold); non_maximum_suppression(&mut bboxes, nms_threshold);
let bboxes = &bboxes[0];
// Annotate the original image and print boxes information. // Annotate the original image and print boxes information.
let (initial_h, initial_w) = (img.height(), img.width()); let (initial_h, initial_w) = (img.height(), img.width());
let w_ratio = initial_w as f32 / w as f32; let w_ratio = initial_w as f32 / w as f32;
let h_ratio = initial_h as f32 / h as f32; let h_ratio = initial_h as f32 / h as f32;
let mut img = img.to_rgb8(); let mut img = img.to_rgb8();
for bboxes_for_class in bboxes.iter() { for b in bboxes.iter() {
for b in bboxes_for_class.iter() { println!("{b:?}");
println!("{b:?}"); let xmin = (b.xmin * w_ratio) as i32;
let xmin = (b.xmin * w_ratio) as i32; let ymin = (b.ymin * h_ratio) as i32;
let ymin = (b.ymin * h_ratio) as i32; let dx = (b.xmax - b.xmin) * w_ratio;
let dx = (b.xmax - b.xmin) * w_ratio; let dy = (b.ymax - b.ymin) * h_ratio;
let dy = (b.ymax - b.ymin) * h_ratio; if dx >= 0. && dy >= 0. {
if dx >= 0. && dy >= 0. { imageproc::drawing::draw_hollow_rect_mut(
imageproc::drawing::draw_hollow_rect_mut( &mut img,
&mut img, imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, dy as u32),
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, dy as u32), image::Rgb([255, 0, 0]),
image::Rgb([255, 0, 0]), );
); }
} for kp in b.keypoints.iter() {
for kp in b.keypoints.iter() { if kp.mask < 0.6 {
if kp.mask < 0.6 { continue;
continue;
}
let x = (kp.x * w_ratio) as i32;
let y = (kp.y * h_ratio) as i32;
imageproc::drawing::draw_filled_circle_mut(
&mut img,
(x, y),
2,
image::Rgb([0, 255, 0]),
);
} }
let x = (kp.x * w_ratio) as i32;
let y = (kp.y * h_ratio) as i32;
imageproc::drawing::draw_filled_circle_mut(
&mut img,
(x, y),
2,
image::Rgb([0, 255, 0]),
);
}
for &(idx1, idx2) in KP_CONNECTIONS.iter() { for &(idx1, idx2) in KP_CONNECTIONS.iter() {
let kp1 = &b.keypoints[idx1]; let kp1 = &b.keypoints[idx1];
let kp2 = &b.keypoints[idx2]; let kp2 = &b.keypoints[idx2];
if kp1.mask < 0.6 || kp2.mask < 0.6 { if kp1.mask < 0.6 || kp2.mask < 0.6 {
continue; continue;
}
imageproc::drawing::draw_line_segment_mut(
&mut img,
(kp1.x * w_ratio, kp1.y * h_ratio),
(kp2.x * w_ratio, kp2.y * h_ratio),
image::Rgb([255, 255, 0]),
);
} }
imageproc::drawing::draw_line_segment_mut(
&mut img,
(kp1.x * w_ratio, kp1.y * h_ratio),
(kp2.x * w_ratio, kp2.y * h_ratio),
image::Rgb([255, 255, 0]),
);
} }
} }
Ok(DynamicImage::ImageRgb8(img)) Ok(DynamicImage::ImageRgb8(img))