Add wasm support for yolo-v8 pose detection. (#630)

* Add wasm support for yolo-v8 pose detection.

* Better bbox handling.

* Add the pose model in the wasm example lib.
This commit is contained in:
Laurent Mazare
2023-08-27 19:49:24 +01:00
committed by GitHub
parent 72ebb12bca
commit 24dda44c27
3 changed files with 315 additions and 40 deletions

View File

@ -1,6 +1,7 @@
use candle_wasm_example_yolo::coco_classes;
use candle_wasm_example_yolo::model::Bbox;
use candle_wasm_example_yolo::worker::Model as M;
use candle_wasm_example_yolo::worker::ModelPose as P;
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
@ -26,18 +27,9 @@ impl Model {
let bboxes = self.inner.run(image, conf_threshold, iou_threshold)?;
let mut detections: Vec<(String, Bbox)> = vec![];
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
detections.push((
coco_classes::NAMES[class_index].to_string(),
Bbox {
xmin: b.xmin,
ymin: b.ymin,
xmax: b.xmax,
ymax: b.ymax,
confidence: b.confidence,
},
));
for (class_index, bboxes_for_class) in bboxes.into_iter().enumerate() {
for b in bboxes_for_class.into_iter() {
detections.push((coco_classes::NAMES[class_index].to_string(), b));
}
}
let json = serde_json::to_string(&detections)?;
@ -45,4 +37,30 @@ impl Model {
}
}
#[wasm_bindgen]
pub struct ModelPose {
inner: P,
}
#[wasm_bindgen]
impl ModelPose {
#[wasm_bindgen(constructor)]
pub fn new(data: Vec<u8>, model_size: &str) -> Result<ModelPose, JsError> {
let inner = P::load_(&data, model_size)?;
Ok(Self { inner })
}
#[wasm_bindgen]
pub fn run(
&self,
image: Vec<u8>,
conf_threshold: f32,
iou_threshold: f32,
) -> Result<String, JsError> {
let bboxes = self.inner.run(image, conf_threshold, iou_threshold)?;
let json = serde_json::to_string(&bboxes)?;
Ok(json)
}
}
fn main() {}

View File

@ -445,6 +445,13 @@ struct DetectionHead {
no: usize,
}
#[derive(Debug)]
struct PoseHead {
detect: DetectionHead,
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
kpt: (usize, usize),
}
fn make_anchors(
xs0: &Tensor,
xs1: &Tensor,
@ -475,6 +482,13 @@ fn make_anchors(
let stride_tensor = Tensor::cat(stride_tensor.as_slice(), 0)?.unsqueeze(1)?;
Ok((anchor_points, stride_tensor))
}
struct DetectionHeadOut {
pred: Tensor,
anchors: Tensor,
strides: Tensor,
}
fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result<Tensor> {
let chunks = distance.chunk(2, 1)?;
let lt = &chunks[0];
@ -536,7 +550,7 @@ impl DetectionHead {
Ok((block0, block1, conv))
}
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
let forward_cv = |xs, i: usize| {
let xs_2 = self.cv2[i].0.forward(xs)?;
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
@ -552,7 +566,7 @@ impl DetectionHead {
let xs2 = forward_cv(xs2, 2)?;
let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
let anchors = anchors.transpose(0, 1)?;
let anchors = anchors.transpose(0, 1)?.unsqueeze(0)?;
let strides = strides.transpose(0, 1)?;
let reshape = |xs: &Tensor| {
@ -568,9 +582,70 @@ impl DetectionHead {
let box_ = x_cat.i((.., ..self.ch * 4))?;
let cls = x_cat.i((.., self.ch * 4..))?;
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors)?;
let dbox = dbox.broadcast_mul(&strides)?;
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
let pred = Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)?;
Ok(DetectionHeadOut {
pred,
anchors,
strides,
})
}
}
impl PoseHead {
// kpt: keypoints, (17, 3)
// nc: num-classes, 80
fn load(
vb: VarBuilder,
nc: usize,
kpt: (usize, usize),
filters: (usize, usize, usize),
) -> Result<Self> {
let detect = DetectionHead::load(vb.clone(), nc, filters)?;
let nk = kpt.0 * kpt.1;
let c4 = usize::max(filters.0 / 4, nk);
let cv4 = [
Self::load_cv4(vb.pp("cv4.0"), c4, nk, filters.0)?,
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
];
Ok(Self { detect, cv4, kpt })
}
fn load_cv4(
vb: VarBuilder,
c1: usize,
nc: usize,
filter: usize,
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
Ok((block0, block1, conv))
}
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
let d = self.detect.forward(xs0, xs1, xs2)?;
let forward_cv = |xs: &Tensor, i: usize| {
let (b_sz, _, h, w) = xs.dims4()?;
let xs = self.cv4[i].0.forward(xs)?;
let xs = self.cv4[i].1.forward(&xs)?;
let xs = self.cv4[i].2.forward(&xs)?;
xs.reshape((b_sz, self.kpt.0 * self.kpt.1, h * w))
};
let xs0 = forward_cv(xs0, 0)?;
let xs1 = forward_cv(xs1, 1)?;
let xs2 = forward_cv(xs2, 2)?;
let xs = Tensor::cat(&[xs0, xs1, xs2], D::Minus1)?;
let (b_sz, _nk, hw) = xs.dims3()?;
let xs = xs.reshape((b_sz, self.kpt.0, self.kpt.1, hw))?;
let ys01 = ((xs.i((.., .., 0..2))? * 2.)?.broadcast_add(&d.anchors)? - 0.5)?
.broadcast_mul(&d.strides)?;
let ys2 = candle_nn::ops::sigmoid(&xs.i((.., .., 2..3))?)?;
let ys = Tensor::cat(&[ys01, ys2], 2)?.flatten(1, 2)?;
Tensor::cat(&[d.pred, ys], 1)
}
}
@ -591,6 +666,35 @@ impl YoloV8 {
}
impl Module for YoloV8 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
}
}
#[derive(Debug)]
pub struct YoloV8Pose {
net: DarkNet,
fpn: YoloV8Neck,
head: PoseHead,
}
impl YoloV8Pose {
pub fn load(
vb: VarBuilder,
m: Multiples,
num_classes: usize,
kpt: (usize, usize),
) -> Result<Self> {
let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
Ok(Self { net, fpn, head })
}
}
impl Module for YoloV8Pose {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
@ -598,13 +702,21 @@ impl Module for YoloV8 {
}
}
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct KeyPoint {
pub x: f32,
pub y: f32,
pub mask: f32,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Bbox {
pub xmin: f32,
pub ymin: f32,
pub xmax: f32,
pub ymax: f32,
pub confidence: f32,
pub keypoints: Vec<KeyPoint>,
}
// Intersection over union of two bounding boxes.
@ -619,7 +731,7 @@ fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
i_area / (b1_area + b2_area - i_area)
}
pub fn report(
pub fn report_detect(
pred: &Tensor,
img: DynamicImage,
w: usize,
@ -651,31 +763,15 @@ pub fn report(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints: vec![],
};
bboxes[class_index].push(bbox)
}
}
}
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut current_index = 0;
for index in 0..bboxes_for_class.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
if iou > iou_threshold {
drop = true;
break;
}
}
if !drop {
bboxes_for_class.swap(current_index, index);
current_index += 1;
}
}
bboxes_for_class.truncate(current_index);
}
non_maximum_suppression(&mut bboxes, iou_threshold);
// Annotate the original image and print boxes information.
let (initial_h, initial_w) = (img.height() as f32, img.width() as f32);
let w_ratio = initial_w / w as f32;
@ -691,3 +787,84 @@ pub fn report(
}
Ok(bboxes)
}
fn non_maximum_suppression(bboxes: &mut [Vec<Bbox>], threshold: f32) {
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut current_index = 0;
for index in 0..bboxes_for_class.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
if iou > threshold {
drop = true;
break;
}
}
if !drop {
bboxes_for_class.swap(current_index, index);
current_index += 1;
}
}
bboxes_for_class.truncate(current_index);
}
}
pub fn report_pose(
pred: &Tensor,
img: DynamicImage,
w: usize,
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<Vec<Bbox>> {
let (pred_size, npreds) = pred.dims2()?;
if pred_size != 17 * 3 + 4 + 1 {
candle::bail!("unexpected pred-size {pred_size}");
}
let mut bboxes = vec![];
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
let confidence = pred[4];
if confidence > confidence_threshold {
let keypoints = (0..17)
.map(|i| KeyPoint {
x: pred[3 * i + 5],
y: pred[3 * i + 6],
mask: pred[3 * i + 7],
})
.collect::<Vec<_>>();
let bbox = Bbox {
xmin: pred[0] - pred[2] / 2.,
ymin: pred[1] - pred[3] / 2.,
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints,
};
bboxes.push(bbox)
}
}
let mut bboxes = vec![bboxes];
non_maximum_suppression(&mut bboxes, nms_threshold);
let mut bboxes = bboxes.into_iter().next().unwrap();
let (initial_h, initial_w) = (img.height() as f32, img.width() as f32);
let w_ratio = initial_w / w as f32;
let h_ratio = initial_h / h as f32;
for b in bboxes.iter_mut() {
crate::console_log!("detected {b:?}");
b.xmin = (b.xmin * w_ratio).clamp(0., initial_w - 1.);
b.ymin = (b.ymin * h_ratio).clamp(0., initial_h - 1.);
b.xmax = (b.xmax * w_ratio).clamp(0., initial_w - 1.);
b.ymax = (b.ymax * h_ratio).clamp(0., initial_h - 1.);
for kp in b.keypoints.iter_mut() {
kp.x = (kp.x * w_ratio).clamp(0., initial_w - 1.);
kp.y = (kp.y * h_ratio).clamp(0., initial_h - 1.);
}
}
Ok(bboxes)
}

View File

@ -1,4 +1,4 @@
use crate::model::{report, Bbox, Multiples, YoloV8};
use crate::model::{report_detect, report_pose, Bbox, Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use serde::{Deserialize, Serialize};
@ -81,7 +81,7 @@ impl Model {
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = self.model.forward(&image_t)?.squeeze(0)?;
console_log!("generated predictions {predictions:?}");
let bboxes = report(
let bboxes = report_detect(
&predictions,
original_image,
width,
@ -115,6 +115,86 @@ impl Model {
}
}
pub struct ModelPose {
model: YoloV8Pose,
}
impl ModelPose {
pub fn run(
&self,
image_data: Vec<u8>,
conf_threshold: f32,
iou_threshold: f32,
) -> Result<Vec<Bbox>> {
console_log!("image data: {}", image_data.len());
let image_data = std::io::Cursor::new(image_data);
let original_image = image::io::Reader::new(image_data)
.with_guessed_format()?
.decode()
.map_err(candle::Error::wrap)?;
let (width, height) = {
let w = original_image.width() as usize;
let h = original_image.height() as usize;
if w < h {
let w = w * 640 / h;
// Sizes have to be divisible by 32.
(w / 32 * 32, 640)
} else {
let h = h * 640 / w;
(640, h / 32 * 32)
}
};
let image_t = {
let img = original_image.resize_exact(
width as u32,
height as u32,
image::imageops::FilterType::CatmullRom,
);
let data = img.to_rgb8().into_raw();
Tensor::from_vec(
data,
(img.height() as usize, img.width() as usize, 3),
&Device::Cpu,
)?
.permute((2, 0, 1))?
};
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = self.model.forward(&image_t)?.squeeze(0)?;
console_log!("generated predictions {predictions:?}");
let bboxes = report_pose(
&predictions,
original_image,
width,
height,
conf_threshold,
iou_threshold,
)?;
Ok(bboxes)
}
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
let multiples = match model_size {
"n" => Multiples::n(),
"s" => Multiples::s(),
"m" => Multiples::m(),
"l" => Multiples::l(),
"x" => Multiples::x(),
_ => Err(candle::Error::Msg(
"invalid model size: must be n, s, m, l or x".to_string(),
))?,
};
let dev = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?;
Ok(Self { model })
}
pub fn load(md: ModelData) -> Result<Self> {
Self::load_(&md.weights, &md.model_size.to_string())
}
}
pub struct Worker {
link: WorkerLink<Self>,
model: Option<Model>,