mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Automatic mask generation (#779)
* A few more contiguous fixes for cuda. * Mask generation. * Generic bbox. * Generate all the masks.
This commit is contained in:
@ -64,7 +64,7 @@ pub fn report_detect(
|
||||
let (pred_size, npreds) = pred.dims2()?;
|
||||
let nclasses = pred_size - 4;
|
||||
// The bounding boxes grouped by (maximum) class index.
|
||||
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
|
||||
let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
|
||||
// Extract the bounding boxes for which confidence is above the threshold.
|
||||
for index in 0..npreds {
|
||||
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
||||
@ -83,7 +83,7 @@ pub fn report_detect(
|
||||
xmax: pred[0] + pred[2] / 2.,
|
||||
ymax: pred[1] + pred[3] / 2.,
|
||||
confidence,
|
||||
keypoints: vec![],
|
||||
data: vec![],
|
||||
};
|
||||
bboxes[class_index].push(bbox)
|
||||
}
|
||||
@ -176,7 +176,7 @@ pub fn report_pose(
|
||||
xmax: pred[0] + pred[2] / 2.,
|
||||
ymax: pred[1] + pred[3] / 2.,
|
||||
confidence,
|
||||
keypoints,
|
||||
data: keypoints,
|
||||
};
|
||||
bboxes.push(bbox)
|
||||
}
|
||||
@ -204,7 +204,7 @@ pub fn report_pose(
|
||||
image::Rgb([255, 0, 0]),
|
||||
);
|
||||
}
|
||||
for kp in b.keypoints.iter() {
|
||||
for kp in b.data.iter() {
|
||||
if kp.mask < 0.6 {
|
||||
continue;
|
||||
}
|
||||
@ -219,8 +219,8 @@ pub fn report_pose(
|
||||
}
|
||||
|
||||
for &(idx1, idx2) in KP_CONNECTIONS.iter() {
|
||||
let kp1 = &b.keypoints[idx1];
|
||||
let kp2 = &b.keypoints[idx2];
|
||||
let kp1 = &b.data[idx1];
|
||||
let kp2 = &b.data[idx2];
|
||||
if kp1.mask < 0.6 || kp2.mask < 0.6 {
|
||||
continue;
|
||||
}
|
||||
|
Reference in New Issue
Block a user