Automatic mask generation (#779)

* A few more contiguous fixes for cuda.

* Mask generation.

* Generic bbox.

* Generate all the masks.
This commit is contained in:
Laurent Mazare
2023-09-08 19:11:34 +01:00
committed by GitHub
parent 158ff3c609
commit 0906acab91
7 changed files with 125 additions and 26 deletions

View File

@ -64,7 +64,7 @@ pub fn report_detect(
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
@ -83,7 +83,7 @@ pub fn report_detect(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints: vec![],
data: vec![],
};
bboxes[class_index].push(bbox)
}
@ -176,7 +176,7 @@ pub fn report_pose(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints,
data: keypoints,
};
bboxes.push(bbox)
}
@ -204,7 +204,7 @@ pub fn report_pose(
image::Rgb([255, 0, 0]),
);
}
for kp in b.keypoints.iter() {
for kp in b.data.iter() {
if kp.mask < 0.6 {
continue;
}
@ -219,8 +219,8 @@ pub fn report_pose(
}
for &(idx1, idx2) in KP_CONNECTIONS.iter() {
let kp1 = &b.keypoints[idx1];
let kp2 = &b.keypoints[idx2];
let kp1 = &b.data[idx1];
let kp2 = &b.data[idx2];
if kp1.mask < 0.6 || kp2.mask < 0.6 {
continue;
}