mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add wasm support for yolo-v8 pose detection. (#630)
* Add wasm support for yolo-v8 pose detection. * Better bbox handling. * Add the pose model in the wasm example lib.
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
use candle_wasm_example_yolo::coco_classes;
|
||||
use candle_wasm_example_yolo::model::Bbox;
|
||||
use candle_wasm_example_yolo::worker::Model as M;
|
||||
use candle_wasm_example_yolo::worker::ModelPose as P;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
#[wasm_bindgen]
|
||||
@ -26,18 +27,9 @@ impl Model {
|
||||
let bboxes = self.inner.run(image, conf_threshold, iou_threshold)?;
|
||||
let mut detections: Vec<(String, Bbox)> = vec![];
|
||||
|
||||
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
|
||||
for b in bboxes_for_class.iter() {
|
||||
detections.push((
|
||||
coco_classes::NAMES[class_index].to_string(),
|
||||
Bbox {
|
||||
xmin: b.xmin,
|
||||
ymin: b.ymin,
|
||||
xmax: b.xmax,
|
||||
ymax: b.ymax,
|
||||
confidence: b.confidence,
|
||||
},
|
||||
));
|
||||
for (class_index, bboxes_for_class) in bboxes.into_iter().enumerate() {
|
||||
for b in bboxes_for_class.into_iter() {
|
||||
detections.push((coco_classes::NAMES[class_index].to_string(), b));
|
||||
}
|
||||
}
|
||||
let json = serde_json::to_string(&detections)?;
|
||||
@ -45,4 +37,30 @@ impl Model {
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct ModelPose {
|
||||
inner: P,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl ModelPose {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(data: Vec<u8>, model_size: &str) -> Result<ModelPose, JsError> {
|
||||
let inner = P::load_(&data, model_size)?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn run(
|
||||
&self,
|
||||
image: Vec<u8>,
|
||||
conf_threshold: f32,
|
||||
iou_threshold: f32,
|
||||
) -> Result<String, JsError> {
|
||||
let bboxes = self.inner.run(image, conf_threshold, iou_threshold)?;
|
||||
let json = serde_json::to_string(&bboxes)?;
|
||||
Ok(json)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
|
@ -445,6 +445,13 @@ struct DetectionHead {
|
||||
no: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PoseHead {
|
||||
detect: DetectionHead,
|
||||
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
|
||||
kpt: (usize, usize),
|
||||
}
|
||||
|
||||
fn make_anchors(
|
||||
xs0: &Tensor,
|
||||
xs1: &Tensor,
|
||||
@ -475,6 +482,13 @@ fn make_anchors(
|
||||
let stride_tensor = Tensor::cat(stride_tensor.as_slice(), 0)?.unsqueeze(1)?;
|
||||
Ok((anchor_points, stride_tensor))
|
||||
}
|
||||
|
||||
struct DetectionHeadOut {
|
||||
pred: Tensor,
|
||||
anchors: Tensor,
|
||||
strides: Tensor,
|
||||
}
|
||||
|
||||
fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result<Tensor> {
|
||||
let chunks = distance.chunk(2, 1)?;
|
||||
let lt = &chunks[0];
|
||||
@ -536,7 +550,7 @@ impl DetectionHead {
|
||||
Ok((block0, block1, conv))
|
||||
}
|
||||
|
||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
|
||||
let forward_cv = |xs, i: usize| {
|
||||
let xs_2 = self.cv2[i].0.forward(xs)?;
|
||||
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
|
||||
@ -552,7 +566,7 @@ impl DetectionHead {
|
||||
let xs2 = forward_cv(xs2, 2)?;
|
||||
|
||||
let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
|
||||
let anchors = anchors.transpose(0, 1)?;
|
||||
let anchors = anchors.transpose(0, 1)?.unsqueeze(0)?;
|
||||
let strides = strides.transpose(0, 1)?;
|
||||
|
||||
let reshape = |xs: &Tensor| {
|
||||
@ -568,9 +582,70 @@ impl DetectionHead {
|
||||
let box_ = x_cat.i((.., ..self.ch * 4))?;
|
||||
let cls = x_cat.i((.., self.ch * 4..))?;
|
||||
|
||||
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
|
||||
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors)?;
|
||||
let dbox = dbox.broadcast_mul(&strides)?;
|
||||
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
|
||||
let pred = Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)?;
|
||||
Ok(DetectionHeadOut {
|
||||
pred,
|
||||
anchors,
|
||||
strides,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PoseHead {
|
||||
// kpt: keypoints, (17, 3)
|
||||
// nc: num-classes, 80
|
||||
fn load(
|
||||
vb: VarBuilder,
|
||||
nc: usize,
|
||||
kpt: (usize, usize),
|
||||
filters: (usize, usize, usize),
|
||||
) -> Result<Self> {
|
||||
let detect = DetectionHead::load(vb.clone(), nc, filters)?;
|
||||
let nk = kpt.0 * kpt.1;
|
||||
let c4 = usize::max(filters.0 / 4, nk);
|
||||
let cv4 = [
|
||||
Self::load_cv4(vb.pp("cv4.0"), c4, nk, filters.0)?,
|
||||
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
|
||||
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
|
||||
];
|
||||
Ok(Self { detect, cv4, kpt })
|
||||
}
|
||||
|
||||
fn load_cv4(
|
||||
vb: VarBuilder,
|
||||
c1: usize,
|
||||
nc: usize,
|
||||
filter: usize,
|
||||
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
||||
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
|
||||
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
|
||||
let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
|
||||
Ok((block0, block1, conv))
|
||||
}
|
||||
|
||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
||||
let d = self.detect.forward(xs0, xs1, xs2)?;
|
||||
let forward_cv = |xs: &Tensor, i: usize| {
|
||||
let (b_sz, _, h, w) = xs.dims4()?;
|
||||
let xs = self.cv4[i].0.forward(xs)?;
|
||||
let xs = self.cv4[i].1.forward(&xs)?;
|
||||
let xs = self.cv4[i].2.forward(&xs)?;
|
||||
xs.reshape((b_sz, self.kpt.0 * self.kpt.1, h * w))
|
||||
};
|
||||
let xs0 = forward_cv(xs0, 0)?;
|
||||
let xs1 = forward_cv(xs1, 1)?;
|
||||
let xs2 = forward_cv(xs2, 2)?;
|
||||
let xs = Tensor::cat(&[xs0, xs1, xs2], D::Minus1)?;
|
||||
let (b_sz, _nk, hw) = xs.dims3()?;
|
||||
let xs = xs.reshape((b_sz, self.kpt.0, self.kpt.1, hw))?;
|
||||
|
||||
let ys01 = ((xs.i((.., .., 0..2))? * 2.)?.broadcast_add(&d.anchors)? - 0.5)?
|
||||
.broadcast_mul(&d.strides)?;
|
||||
let ys2 = candle_nn::ops::sigmoid(&xs.i((.., .., 2..3))?)?;
|
||||
let ys = Tensor::cat(&[ys01, ys2], 2)?.flatten(1, 2)?;
|
||||
Tensor::cat(&[d.pred, ys], 1)
|
||||
}
|
||||
}
|
||||
|
||||
@ -591,6 +666,35 @@ impl YoloV8 {
|
||||
}
|
||||
|
||||
impl Module for YoloV8 {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct YoloV8Pose {
|
||||
net: DarkNet,
|
||||
fpn: YoloV8Neck,
|
||||
head: PoseHead,
|
||||
}
|
||||
|
||||
impl YoloV8Pose {
|
||||
pub fn load(
|
||||
vb: VarBuilder,
|
||||
m: Multiples,
|
||||
num_classes: usize,
|
||||
kpt: (usize, usize),
|
||||
) -> Result<Self> {
|
||||
let net = DarkNet::load(vb.pp("net"), m)?;
|
||||
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
|
||||
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
|
||||
Ok(Self { net, fpn, head })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for YoloV8Pose {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||
@ -598,13 +702,21 @@ impl Module for YoloV8 {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
pub struct KeyPoint {
|
||||
pub x: f32,
|
||||
pub y: f32,
|
||||
pub mask: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Bbox {
|
||||
pub xmin: f32,
|
||||
pub ymin: f32,
|
||||
pub xmax: f32,
|
||||
pub ymax: f32,
|
||||
pub confidence: f32,
|
||||
pub keypoints: Vec<KeyPoint>,
|
||||
}
|
||||
|
||||
// Intersection over union of two bounding boxes.
|
||||
@ -619,7 +731,7 @@ fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
|
||||
i_area / (b1_area + b2_area - i_area)
|
||||
}
|
||||
|
||||
pub fn report(
|
||||
pub fn report_detect(
|
||||
pred: &Tensor,
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
@ -651,31 +763,15 @@ pub fn report(
|
||||
xmax: pred[0] + pred[2] / 2.,
|
||||
ymax: pred[1] + pred[3] / 2.,
|
||||
confidence,
|
||||
keypoints: vec![],
|
||||
};
|
||||
bboxes[class_index].push(bbox)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Perform non-maximum suppression.
|
||||
for bboxes_for_class in bboxes.iter_mut() {
|
||||
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
|
||||
let mut current_index = 0;
|
||||
for index in 0..bboxes_for_class.len() {
|
||||
let mut drop = false;
|
||||
for prev_index in 0..current_index {
|
||||
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
|
||||
if iou > iou_threshold {
|
||||
drop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !drop {
|
||||
bboxes_for_class.swap(current_index, index);
|
||||
current_index += 1;
|
||||
}
|
||||
}
|
||||
bboxes_for_class.truncate(current_index);
|
||||
}
|
||||
|
||||
non_maximum_suppression(&mut bboxes, iou_threshold);
|
||||
|
||||
// Annotate the original image and print boxes information.
|
||||
let (initial_h, initial_w) = (img.height() as f32, img.width() as f32);
|
||||
let w_ratio = initial_w / w as f32;
|
||||
@ -691,3 +787,84 @@ pub fn report(
|
||||
}
|
||||
Ok(bboxes)
|
||||
}
|
||||
|
||||
fn non_maximum_suppression(bboxes: &mut [Vec<Bbox>], threshold: f32) {
|
||||
// Perform non-maximum suppression.
|
||||
for bboxes_for_class in bboxes.iter_mut() {
|
||||
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
|
||||
let mut current_index = 0;
|
||||
for index in 0..bboxes_for_class.len() {
|
||||
let mut drop = false;
|
||||
for prev_index in 0..current_index {
|
||||
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
|
||||
if iou > threshold {
|
||||
drop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !drop {
|
||||
bboxes_for_class.swap(current_index, index);
|
||||
current_index += 1;
|
||||
}
|
||||
}
|
||||
bboxes_for_class.truncate(current_index);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn report_pose(
|
||||
pred: &Tensor,
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
h: usize,
|
||||
confidence_threshold: f32,
|
||||
nms_threshold: f32,
|
||||
) -> Result<Vec<Bbox>> {
|
||||
let (pred_size, npreds) = pred.dims2()?;
|
||||
if pred_size != 17 * 3 + 4 + 1 {
|
||||
candle::bail!("unexpected pred-size {pred_size}");
|
||||
}
|
||||
let mut bboxes = vec![];
|
||||
// Extract the bounding boxes for which confidence is above the threshold.
|
||||
for index in 0..npreds {
|
||||
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
||||
let confidence = pred[4];
|
||||
if confidence > confidence_threshold {
|
||||
let keypoints = (0..17)
|
||||
.map(|i| KeyPoint {
|
||||
x: pred[3 * i + 5],
|
||||
y: pred[3 * i + 6],
|
||||
mask: pred[3 * i + 7],
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let bbox = Bbox {
|
||||
xmin: pred[0] - pred[2] / 2.,
|
||||
ymin: pred[1] - pred[3] / 2.,
|
||||
xmax: pred[0] + pred[2] / 2.,
|
||||
ymax: pred[1] + pred[3] / 2.,
|
||||
confidence,
|
||||
keypoints,
|
||||
};
|
||||
bboxes.push(bbox)
|
||||
}
|
||||
}
|
||||
|
||||
let mut bboxes = vec![bboxes];
|
||||
non_maximum_suppression(&mut bboxes, nms_threshold);
|
||||
let mut bboxes = bboxes.into_iter().next().unwrap();
|
||||
|
||||
let (initial_h, initial_w) = (img.height() as f32, img.width() as f32);
|
||||
let w_ratio = initial_w / w as f32;
|
||||
let h_ratio = initial_h / h as f32;
|
||||
for b in bboxes.iter_mut() {
|
||||
crate::console_log!("detected {b:?}");
|
||||
b.xmin = (b.xmin * w_ratio).clamp(0., initial_w - 1.);
|
||||
b.ymin = (b.ymin * h_ratio).clamp(0., initial_h - 1.);
|
||||
b.xmax = (b.xmax * w_ratio).clamp(0., initial_w - 1.);
|
||||
b.ymax = (b.ymax * h_ratio).clamp(0., initial_h - 1.);
|
||||
for kp in b.keypoints.iter_mut() {
|
||||
kp.x = (kp.x * w_ratio).clamp(0., initial_w - 1.);
|
||||
kp.y = (kp.y * h_ratio).clamp(0., initial_h - 1.);
|
||||
}
|
||||
}
|
||||
Ok(bboxes)
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::model::{report, Bbox, Multiples, YoloV8};
|
||||
use crate::model::{report_detect, report_pose, Bbox, Multiples, YoloV8, YoloV8Pose};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -81,7 +81,7 @@ impl Model {
|
||||
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||
let predictions = self.model.forward(&image_t)?.squeeze(0)?;
|
||||
console_log!("generated predictions {predictions:?}");
|
||||
let bboxes = report(
|
||||
let bboxes = report_detect(
|
||||
&predictions,
|
||||
original_image,
|
||||
width,
|
||||
@ -115,6 +115,86 @@ impl Model {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ModelPose {
|
||||
model: YoloV8Pose,
|
||||
}
|
||||
|
||||
impl ModelPose {
|
||||
pub fn run(
|
||||
&self,
|
||||
image_data: Vec<u8>,
|
||||
conf_threshold: f32,
|
||||
iou_threshold: f32,
|
||||
) -> Result<Vec<Bbox>> {
|
||||
console_log!("image data: {}", image_data.len());
|
||||
let image_data = std::io::Cursor::new(image_data);
|
||||
let original_image = image::io::Reader::new(image_data)
|
||||
.with_guessed_format()?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let (width, height) = {
|
||||
let w = original_image.width() as usize;
|
||||
let h = original_image.height() as usize;
|
||||
if w < h {
|
||||
let w = w * 640 / h;
|
||||
// Sizes have to be divisible by 32.
|
||||
(w / 32 * 32, 640)
|
||||
} else {
|
||||
let h = h * 640 / w;
|
||||
(640, h / 32 * 32)
|
||||
}
|
||||
};
|
||||
let image_t = {
|
||||
let img = original_image.resize_exact(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
let data = img.to_rgb8().into_raw();
|
||||
Tensor::from_vec(
|
||||
data,
|
||||
(img.height() as usize, img.width() as usize, 3),
|
||||
&Device::Cpu,
|
||||
)?
|
||||
.permute((2, 0, 1))?
|
||||
};
|
||||
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||
let predictions = self.model.forward(&image_t)?.squeeze(0)?;
|
||||
console_log!("generated predictions {predictions:?}");
|
||||
let bboxes = report_pose(
|
||||
&predictions,
|
||||
original_image,
|
||||
width,
|
||||
height,
|
||||
conf_threshold,
|
||||
iou_threshold,
|
||||
)?;
|
||||
Ok(bboxes)
|
||||
}
|
||||
|
||||
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
|
||||
let multiples = match model_size {
|
||||
"n" => Multiples::n(),
|
||||
"s" => Multiples::s(),
|
||||
"m" => Multiples::m(),
|
||||
"l" => Multiples::l(),
|
||||
"x" => Multiples::x(),
|
||||
_ => Err(candle::Error::Msg(
|
||||
"invalid model size: must be n, s, m, l or x".to_string(),
|
||||
))?,
|
||||
};
|
||||
let dev = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
|
||||
let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?;
|
||||
Ok(Self { model })
|
||||
}
|
||||
|
||||
pub fn load(md: ModelData) -> Result<Self> {
|
||||
Self::load_(&md.weights, &md.model_size.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Worker {
|
||||
link: WorkerLink<Self>,
|
||||
model: Option<Model>,
|
||||
|
Reference in New Issue
Block a user