mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Cleanup the pose reporting code. (#605)
This commit is contained in:
@ -154,24 +154,18 @@ pub fn report_pose(
|
||||
nms_threshold: f32,
|
||||
) -> Result<DynamicImage> {
|
||||
let (pred_size, npreds) = pred.dims2()?;
|
||||
let nclasses = pred_size - 4;
|
||||
// The bounding boxes grouped by (maximum) class index.
|
||||
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
|
||||
if pred_size != 17 * 3 + 4 + 1 {
|
||||
candle::bail!("unexpected pred-size {pred_size}");
|
||||
}
|
||||
let mut bboxes = vec![];
|
||||
// Extract the bounding boxes for which confidence is above the threshold.
|
||||
for index in 0..npreds {
|
||||
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
||||
let confidence = pred[4];
|
||||
if confidence > confidence_threshold {
|
||||
let mut class_index = 0;
|
||||
for i in 0..nclasses {
|
||||
if pred[4 + i] > pred[4 + class_index] {
|
||||
class_index = i
|
||||
}
|
||||
}
|
||||
if pred[class_index + 4] > 0. {
|
||||
let keypoints = (0..17)
|
||||
.map(|i| KeyPoint {
|
||||
x: pred[5 + 3 * i],
|
||||
x: pred[3 * i + 5],
|
||||
y: pred[3 * i + 6],
|
||||
mask: pred[3 * i + 7],
|
||||
})
|
||||
@ -184,20 +178,20 @@ pub fn report_pose(
|
||||
confidence,
|
||||
keypoints,
|
||||
};
|
||||
bboxes[class_index].push(bbox)
|
||||
}
|
||||
bboxes.push(bbox)
|
||||
}
|
||||
}
|
||||
|
||||
let mut bboxes = vec![bboxes];
|
||||
non_maximum_suppression(&mut bboxes, nms_threshold);
|
||||
let bboxes = &bboxes[0];
|
||||
|
||||
// Annotate the original image and print boxes information.
|
||||
let (initial_h, initial_w) = (img.height(), img.width());
|
||||
let w_ratio = initial_w as f32 / w as f32;
|
||||
let h_ratio = initial_h as f32 / h as f32;
|
||||
let mut img = img.to_rgb8();
|
||||
for bboxes_for_class in bboxes.iter() {
|
||||
for b in bboxes_for_class.iter() {
|
||||
for b in bboxes.iter() {
|
||||
println!("{b:?}");
|
||||
let xmin = (b.xmin * w_ratio) as i32;
|
||||
let ymin = (b.ymin * h_ratio) as i32;
|
||||
@ -238,7 +232,6 @@ pub fn report_pose(
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(DynamicImage::ImageRgb8(img))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user