mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Cleanup the pose reporting code. (#605)
This commit is contained in:
@ -154,89 +154,82 @@ pub fn report_pose(
|
|||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
let (pred_size, npreds) = pred.dims2()?;
|
let (pred_size, npreds) = pred.dims2()?;
|
||||||
let nclasses = pred_size - 4;
|
if pred_size != 17 * 3 + 4 + 1 {
|
||||||
// The bounding boxes grouped by (maximum) class index.
|
candle::bail!("unexpected pred-size {pred_size}");
|
||||||
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
|
}
|
||||||
|
let mut bboxes = vec![];
|
||||||
// Extract the bounding boxes for which confidence is above the threshold.
|
// Extract the bounding boxes for which confidence is above the threshold.
|
||||||
for index in 0..npreds {
|
for index in 0..npreds {
|
||||||
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
||||||
let confidence = pred[4];
|
let confidence = pred[4];
|
||||||
if confidence > confidence_threshold {
|
if confidence > confidence_threshold {
|
||||||
let mut class_index = 0;
|
let keypoints = (0..17)
|
||||||
for i in 0..nclasses {
|
.map(|i| KeyPoint {
|
||||||
if pred[4 + i] > pred[4 + class_index] {
|
x: pred[3 * i + 5],
|
||||||
class_index = i
|
y: pred[3 * i + 6],
|
||||||
}
|
mask: pred[3 * i + 7],
|
||||||
}
|
})
|
||||||
if pred[class_index + 4] > 0. {
|
.collect::<Vec<_>>();
|
||||||
let keypoints = (0..17)
|
let bbox = Bbox {
|
||||||
.map(|i| KeyPoint {
|
xmin: pred[0] - pred[2] / 2.,
|
||||||
x: pred[5 + 3 * i],
|
ymin: pred[1] - pred[3] / 2.,
|
||||||
y: pred[3 * i + 6],
|
xmax: pred[0] + pred[2] / 2.,
|
||||||
mask: pred[3 * i + 7],
|
ymax: pred[1] + pred[3] / 2.,
|
||||||
})
|
confidence,
|
||||||
.collect::<Vec<_>>();
|
keypoints,
|
||||||
let bbox = Bbox {
|
};
|
||||||
xmin: pred[0] - pred[2] / 2.,
|
bboxes.push(bbox)
|
||||||
ymin: pred[1] - pred[3] / 2.,
|
|
||||||
xmax: pred[0] + pred[2] / 2.,
|
|
||||||
ymax: pred[1] + pred[3] / 2.,
|
|
||||||
confidence,
|
|
||||||
keypoints,
|
|
||||||
};
|
|
||||||
bboxes[class_index].push(bbox)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut bboxes = vec![bboxes];
|
||||||
non_maximum_suppression(&mut bboxes, nms_threshold);
|
non_maximum_suppression(&mut bboxes, nms_threshold);
|
||||||
|
let bboxes = &bboxes[0];
|
||||||
|
|
||||||
// Annotate the original image and print boxes information.
|
// Annotate the original image and print boxes information.
|
||||||
let (initial_h, initial_w) = (img.height(), img.width());
|
let (initial_h, initial_w) = (img.height(), img.width());
|
||||||
let w_ratio = initial_w as f32 / w as f32;
|
let w_ratio = initial_w as f32 / w as f32;
|
||||||
let h_ratio = initial_h as f32 / h as f32;
|
let h_ratio = initial_h as f32 / h as f32;
|
||||||
let mut img = img.to_rgb8();
|
let mut img = img.to_rgb8();
|
||||||
for bboxes_for_class in bboxes.iter() {
|
for b in bboxes.iter() {
|
||||||
for b in bboxes_for_class.iter() {
|
println!("{b:?}");
|
||||||
println!("{b:?}");
|
let xmin = (b.xmin * w_ratio) as i32;
|
||||||
let xmin = (b.xmin * w_ratio) as i32;
|
let ymin = (b.ymin * h_ratio) as i32;
|
||||||
let ymin = (b.ymin * h_ratio) as i32;
|
let dx = (b.xmax - b.xmin) * w_ratio;
|
||||||
let dx = (b.xmax - b.xmin) * w_ratio;
|
let dy = (b.ymax - b.ymin) * h_ratio;
|
||||||
let dy = (b.ymax - b.ymin) * h_ratio;
|
if dx >= 0. && dy >= 0. {
|
||||||
if dx >= 0. && dy >= 0. {
|
imageproc::drawing::draw_hollow_rect_mut(
|
||||||
imageproc::drawing::draw_hollow_rect_mut(
|
&mut img,
|
||||||
&mut img,
|
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, dy as u32),
|
||||||
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, dy as u32),
|
image::Rgb([255, 0, 0]),
|
||||||
image::Rgb([255, 0, 0]),
|
);
|
||||||
);
|
}
|
||||||
}
|
for kp in b.keypoints.iter() {
|
||||||
for kp in b.keypoints.iter() {
|
if kp.mask < 0.6 {
|
||||||
if kp.mask < 0.6 {
|
continue;
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let x = (kp.x * w_ratio) as i32;
|
|
||||||
let y = (kp.y * h_ratio) as i32;
|
|
||||||
imageproc::drawing::draw_filled_circle_mut(
|
|
||||||
&mut img,
|
|
||||||
(x, y),
|
|
||||||
2,
|
|
||||||
image::Rgb([0, 255, 0]),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
let x = (kp.x * w_ratio) as i32;
|
||||||
|
let y = (kp.y * h_ratio) as i32;
|
||||||
|
imageproc::drawing::draw_filled_circle_mut(
|
||||||
|
&mut img,
|
||||||
|
(x, y),
|
||||||
|
2,
|
||||||
|
image::Rgb([0, 255, 0]),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
for &(idx1, idx2) in KP_CONNECTIONS.iter() {
|
for &(idx1, idx2) in KP_CONNECTIONS.iter() {
|
||||||
let kp1 = &b.keypoints[idx1];
|
let kp1 = &b.keypoints[idx1];
|
||||||
let kp2 = &b.keypoints[idx2];
|
let kp2 = &b.keypoints[idx2];
|
||||||
if kp1.mask < 0.6 || kp2.mask < 0.6 {
|
if kp1.mask < 0.6 || kp2.mask < 0.6 {
|
||||||
continue;
|
continue;
|
||||||
}
|
|
||||||
imageproc::drawing::draw_line_segment_mut(
|
|
||||||
&mut img,
|
|
||||||
(kp1.x * w_ratio, kp1.y * h_ratio),
|
|
||||||
(kp2.x * w_ratio, kp2.y * h_ratio),
|
|
||||||
image::Rgb([255, 255, 0]),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
imageproc::drawing::draw_line_segment_mut(
|
||||||
|
&mut img,
|
||||||
|
(kp1.x * w_ratio, kp1.y * h_ratio),
|
||||||
|
(kp2.x * w_ratio, kp2.y * h_ratio),
|
||||||
|
image::Rgb([255, 255, 0]),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(DynamicImage::ImageRgb8(img))
|
Ok(DynamicImage::ImageRgb8(img))
|
||||||
|
Reference in New Issue
Block a user