Sketch the yolo wasm example. (#546)

* Sketch the yolo wasm example.

* Web ui.

* Get the web ui to work.

* UI tweaks.

* More UI tweaks.

* Use the natural width/height.

* Add a link to the hf space in the readme.
This commit is contained in:
Laurent Mazare
2023-08-22 11:56:43 +01:00
committed by GitHub
parent 44420d8ae1
commit 20ce3e9f39
13 changed files with 1260 additions and 5 deletions

1
.gitignore vendored
View File

@ -26,6 +26,7 @@ flamegraph.svg
trace-*.json trace-*.json
candle-wasm-examples/*/*.bin candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/*.wav candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/package-lock.json candle-wasm-examples/*/package-lock.json

View File

@ -8,6 +8,7 @@ members = [
"candle-transformers", "candle-transformers",
"candle-wasm-examples/llama2-c", "candle-wasm-examples/llama2-c",
"candle-wasm-examples/whisper", "candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
] ]
exclude = [ exclude = [
"candle-flash-attn", "candle-flash-attn",

View File

@ -7,7 +7,8 @@
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos: and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper), [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2). [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
```rust ```rust
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?; let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;

View File

@ -1,5 +1,3 @@
#![allow(dead_code)]
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
@ -156,7 +154,6 @@ struct C2f {
cv1: ConvBlock, cv1: ConvBlock,
cv2: ConvBlock, cv2: ConvBlock,
bottleneck: Vec<Bottleneck>, bottleneck: Vec<Bottleneck>,
c: usize,
} }
impl C2f { impl C2f {
@ -173,7 +170,6 @@ impl C2f {
cv1, cv1,
cv2, cv2,
bottleneck, bottleneck,
c,
}) })
} }
} }

View File

@ -0,0 +1,57 @@
[package]
name = "candle-wasm-example-yolo"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.1.2" }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
image = { workspace = true }
# App crates.
anyhow = { workspace = true }
byteorder = { workspace = true }
log = { workspace = true }
rand = { workspace = true }
safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2"
yew-agent = "0.2.0"
yew = { version = "0.20.0", features = ["csr"] }
[dependencies.web-sys]
version = "0.3.64"
features = [
'Blob',
'CanvasRenderingContext2d',
'Document',
'Element',
'HtmlElement',
'HtmlCanvasElement',
'HtmlImageElement',
'ImageData',
'Node',
'Window',
'Request',
'RequestCache',
'RequestInit',
'RequestMode',
'Response',
'Performance',
'TextMetrics',
]

View File

@ -0,0 +1,17 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Welcome to Candle!</title>
<link data-trunk rel="copy-file" href="yolo.safetensors" />
<link data-trunk rel="copy-file" href="bike.jpeg" />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
</head>
<body></body>
</html>

View File

@ -0,0 +1,268 @@
use crate::console_log;
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use yew::{html, Component, Context, Html};
use yew_agent::{Bridge, Bridged};
async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response};
let window = web_sys::window().ok_or("window")?;
let mut opts = RequestInit::new();
let opts = opts
.method("GET")
.mode(RequestMode::Cors)
.cache(RequestCache::NoCache);
let request = Request::new_with_str_and_init(url, opts)?;
let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
// `resp_value` is a `Response` object.
assert!(resp_value.is_instance_of::<Response>());
let resp: Response = resp_value.dyn_into()?;
let data = JsFuture::from(resp.blob()?).await?;
let blob = web_sys::Blob::from(data);
let array_buffer = JsFuture::from(blob.array_buffer()).await?;
let data = js_sys::Uint8Array::new(&array_buffer).to_vec();
Ok(data)
}
pub enum Msg {
Refresh,
Run,
UpdateStatus(String),
SetModel(ModelData),
WorkerInMsg(WorkerInput),
WorkerOutMsg(Result<WorkerOutput, String>),
}
pub struct CurrentDecode {
start_time: Option<f64>,
}
pub struct App {
status: String,
loaded: bool,
generated: String,
current_decode: Option<CurrentDecode>,
worker: Box<dyn Bridge<Worker>>,
}
async fn model_data_load() -> Result<ModelData, JsValue> {
let weights = fetch_url("yolo.safetensors").await?;
console_log!("loaded weights {}", weights.len());
Ok(ModelData { weights })
}
fn performance_now() -> Option<f64> {
let window = web_sys::window()?;
let performance = window.performance()?;
Some(performance.now() / 1000.)
}
fn draw_bboxes(bboxes: Vec<Vec<crate::model::Bbox>>) -> Result<(), JsValue> {
let document = web_sys::window().unwrap().document().unwrap();
let canvas = match document.get_element_by_id("canvas") {
Some(canvas) => canvas,
None => return Err("no canvas".into()),
};
let canvas: web_sys::HtmlCanvasElement = canvas.dyn_into::<web_sys::HtmlCanvasElement>()?;
let context = canvas
.get_context("2d")?
.ok_or("no 2d")?
.dyn_into::<web_sys::CanvasRenderingContext2d>()?;
let image_html_element = document.get_element_by_id("bike-img");
let image_html_element = match image_html_element {
Some(data) => data,
None => return Err("no bike-img".into()),
};
let image_html_element = image_html_element.dyn_into::<web_sys::HtmlImageElement>()?;
canvas.set_width(image_html_element.natural_width());
canvas.set_height(image_html_element.natural_height());
context.draw_image_with_html_image_element(&image_html_element, 0., 0.)?;
context.set_stroke_style(&JsValue::from("#0dff9a"));
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
let name = crate::coco_classes::NAMES[class_index];
context.stroke_rect(
b.xmin as f64,
b.ymin as f64,
(b.xmax - b.xmin) as f64,
(b.ymax - b.ymin) as f64,
);
if let Ok(metrics) = context.measure_text(name) {
let width = metrics.width();
context.set_fill_style(&"#3c8566".into());
context.fill_rect(b.xmin as f64 - 2., b.ymin as f64 - 12., width + 4., 14.);
context.set_fill_style(&"#e3fff3".into());
context.fill_text(name, b.xmin as f64, b.ymin as f64 - 2.)?
}
}
}
Ok(())
}
impl Component for App {
type Message = Msg;
type Properties = ();
fn create(ctx: &Context<Self>) -> Self {
let status = "loading weights".to_string();
let cb = {
let link = ctx.link().clone();
move |e| link.send_message(Self::Message::WorkerOutMsg(e))
};
let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
status,
generated: String::new(),
current_decode: None,
worker,
loaded: false,
}
}
fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) {
if first_render {
ctx.link().send_future(async {
match model_data_load().await {
Err(err) => {
let status = format!("{err:?}");
Msg::UpdateStatus(status)
}
Ok(model_data) => Msg::SetModel(model_data),
}
});
}
}
fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
match msg {
Msg::SetModel(md) => {
self.status = "weights loaded succesfully!".to_string();
self.loaded = true;
console_log!("loaded weights");
self.worker.send(WorkerInput::ModelData(md));
true
}
Msg::Run => {
if self.current_decode.is_some() {
self.status = "already processing some image at the moment".to_string()
} else {
let start_time = performance_now();
self.current_decode = Some(CurrentDecode { start_time });
self.status = "processing...".to_string();
self.generated.clear();
ctx.link().send_future(async {
match fetch_url("bike.jpeg").await {
Err(err) => {
let status = format!("{err:?}");
Msg::UpdateStatus(status)
}
Ok(image_data) => Msg::WorkerInMsg(WorkerInput::Run(image_data)),
}
});
}
true
}
Msg::WorkerOutMsg(output) => {
match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::ProcessingDone(Err(err))) => {
self.status = format!("error in worker process: {err}");
self.current_decode = None
}
Ok(WorkerOutput::ProcessingDone(Ok(bboxes))) => {
let mut content = Vec::new();
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
content.push(format!(
"bbox {}: xs {:.0}-{:.0} ys {:.0}-{:.0}",
crate::coco_classes::NAMES[class_index],
b.xmin,
b.xmax,
b.ymin,
b.ymax
))
}
}
self.generated = content.join("\n");
let dt = self.current_decode.as_ref().and_then(|current_decode| {
current_decode.start_time.and_then(|start_time| {
performance_now().map(|stop_time| stop_time - start_time)
})
});
self.status = match dt {
None => "processing succeeded!".to_string(),
Some(dt) => format!("processing succeeded in {:.2}s", dt,),
};
self.current_decode = None;
if let Err(err) = draw_bboxes(bboxes) {
self.status = format!("{err:?}")
}
}
Err(err) => {
self.status = format!("error in worker {err:?}");
}
}
true
}
Msg::WorkerInMsg(inp) => {
self.worker.send(inp);
true
}
Msg::UpdateStatus(status) => {
self.status = status;
true
}
Msg::Refresh => true,
}
}
fn view(&self, ctx: &Context<Self>) -> Html {
html! {
<div style="margin: 2%;">
<div><p>{"Running an object detection model in the browser using rust/wasm with "}
<a href="https://github.com/huggingface/candle" target="_blank">{"candle!"}</a>
</p>
<p>{"Once the weights have loaded, click on the run button to process an image."}</p>
<p><img id="bike-img" src="bike.jpeg"/></p>
<p>{"Source: "}<a href="https://commons.wikimedia.org/wiki/File:V%C3%A9lo_parade_-_V%C3%A9lorution_-_bike_critical_mass.JPG">{"wikimedia"}</a></p>
</div>
{
if self.loaded{
html!(<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>)
}else{
html! { <progress id="progress-bar" aria-label="Loading weights..."></progress> }
}
}
<br/ >
<h3>
{&self.status}
</h3>
{
if self.current_decode.is_some() {
html! { <progress id="progress-bar" aria-label="generating…"></progress> }
} else {
html! {}
}
}
<div>
<canvas id="canvas" height="150" width="150"></canvas>
</div>
<blockquote>
<p> { self.generated.chars().map(|c|
if c == '\r' || c == '\n' {
html! { <br/> }
} else {
html! { {c} }
}).collect::<Html>()
} </p>
</blockquote>
</div>
}
}
}

View File

@ -0,0 +1,5 @@
fn main() {
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
console_error_panic_hook::set_once();
yew::Renderer::<candle_wasm_example_yolo::App>::new().render();
}

View File

@ -0,0 +1,5 @@
use yew_agent::PublicWorker;
fn main() {
console_error_panic_hook::set_once();
candle_wasm_example_yolo::Worker::register();
}

View File

@ -0,0 +1,82 @@
pub const NAMES: [&str; 80] = [
"person",
"bicycle",
"car",
"motorbike",
"aeroplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"sofa",
"pottedplant",
"bed",
"diningtable",
"toilet",
"tvmonitor",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
];

View File

@ -0,0 +1,6 @@
mod app;
mod coco_classes;
mod model;
mod worker;
pub use app::App;
pub use worker::Worker;

View File

@ -0,0 +1,684 @@
#![allow(dead_code)]
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
};
use image::DynamicImage;
const CONFIDENCE_THRESHOLD: f32 = 0.5;
const NMS_THRESHOLD: f32 = 0.4;
// Model architecture from https://github.com/ultralytics/ultralytics/issues/189
// https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Multiples {
depth: f64,
width: f64,
ratio: f64,
}
impl Multiples {
pub fn n() -> Self {
Self {
depth: 0.33,
width: 0.25,
ratio: 2.0,
}
}
pub fn s() -> Self {
Self {
depth: 0.33,
width: 0.50,
ratio: 2.0,
}
}
pub fn m() -> Self {
Self {
depth: 0.67,
width: 0.75,
ratio: 1.5,
}
}
pub fn l() -> Self {
Self {
depth: 1.00,
width: 1.00,
ratio: 1.0,
}
}
pub fn x() -> Self {
Self {
depth: 1.00,
width: 1.25,
ratio: 1.0,
}
}
fn filters(&self) -> (usize, usize, usize) {
let f1 = (256. * self.width) as usize;
let f2 = (512. * self.width) as usize;
let f3 = (512. * self.width * self.ratio) as usize;
(f1, f2, f3)
}
}
#[derive(Debug)]
struct Upsample {
scale_factor: usize,
}
impl Upsample {
fn new(scale_factor: usize) -> Result<Self> {
Ok(Upsample { scale_factor })
}
}
impl Module for Upsample {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
let (_b_size, _channels, h, w) = xs.dims4()?;
xs.upsample_nearest2d(self.scale_factor * h, self.scale_factor * w)
}
}
#[derive(Debug)]
struct ConvBlock {
conv: Conv2d,
bn: BatchNorm,
}
impl ConvBlock {
fn load(
vb: VarBuilder,
c1: usize,
c2: usize,
k: usize,
stride: usize,
padding: Option<usize>,
) -> Result<Self> {
let padding = padding.unwrap_or(k / 2);
let cfg = Conv2dConfig { padding, stride };
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
Ok(Self { conv, bn })
}
}
impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv.forward(xs)?;
let xs = self.bn.forward(&xs)?;
candle_nn::ops::silu(&xs)
}
}
#[derive(Debug)]
struct Bottleneck {
cv1: ConvBlock,
cv2: ConvBlock,
residual: bool,
}
impl Bottleneck {
fn load(vb: VarBuilder, c1: usize, c2: usize, shortcut: bool) -> Result<Self> {
let channel_factor = 1.;
let c_ = (c2 as f64 * channel_factor) as usize;
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?;
let residual = c1 == c2 && shortcut;
Ok(Self { cv1, cv2, residual })
}
}
impl Module for Bottleneck {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
if self.residual {
xs + ys
} else {
Ok(ys)
}
}
}
#[derive(Debug)]
struct C2f {
cv1: ConvBlock,
cv2: ConvBlock,
bottleneck: Vec<Bottleneck>,
}
impl C2f {
fn load(vb: VarBuilder, c1: usize, c2: usize, n: usize, shortcut: bool) -> Result<Self> {
let c = (c2 as f64 * 0.5) as usize;
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, 2 * c, 1, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), (2 + n) * c, c2, 1, 1, None)?;
let mut bottleneck = Vec::with_capacity(n);
for idx in 0..n {
let b = Bottleneck::load(vb.pp(&format!("bottleneck.{idx}")), c, c, shortcut)?;
bottleneck.push(b)
}
Ok(Self {
cv1,
cv2,
bottleneck,
})
}
}
impl Module for C2f {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.cv1.forward(xs)?;
let mut ys = ys.chunk(2, 1)?;
for m in self.bottleneck.iter() {
ys.push(m.forward(ys.last().unwrap())?)
}
let zs = Tensor::cat(ys.as_slice(), 1)?;
self.cv2.forward(&zs)
}
}
#[derive(Debug)]
struct Sppf {
cv1: ConvBlock,
cv2: ConvBlock,
k: usize,
}
impl Sppf {
fn load(vb: VarBuilder, c1: usize, c2: usize, k: usize) -> Result<Self> {
let c_ = c1 / 2;
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?;
Ok(Self { cv1, cv2, k })
}
}
impl Module for Sppf {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_, _, _, _) = xs.dims4()?;
let xs = self.cv1.forward(xs)?;
let xs2 = xs
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
.max_pool2d((self.k, self.k), (1, 1))?;
let xs3 = xs2
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
.max_pool2d((self.k, self.k), (1, 1))?;
let xs4 = xs3
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
.max_pool2d((self.k, self.k), (1, 1))?;
self.cv2.forward(&Tensor::cat(&[&xs, &xs2, &xs3, &xs4], 1)?)
}
}
#[derive(Debug)]
struct Dfl {
conv: Conv2d,
num_classes: usize,
}
impl Dfl {
fn load(vb: VarBuilder, num_classes: usize) -> Result<Self> {
let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?;
Ok(Self { conv, num_classes })
}
}
impl Module for Dfl {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b_sz, _channels, anchors) = xs.dims3()?;
let xs = xs
.reshape((b_sz, 4, self.num_classes, anchors))?
.transpose(2, 1)?;
let xs = candle_nn::ops::softmax(&xs, 1)?;
self.conv.forward(&xs)?.reshape((b_sz, 4, anchors))
}
}
#[derive(Debug)]
struct DarkNet {
b1_0: ConvBlock,
b1_1: ConvBlock,
b2_0: C2f,
b2_1: ConvBlock,
b2_2: C2f,
b3_0: ConvBlock,
b3_1: C2f,
b4_0: ConvBlock,
b4_1: C2f,
b5: Sppf,
}
impl DarkNet {
fn load(vb: VarBuilder, m: Multiples) -> Result<Self> {
let (w, r, d) = (m.width, m.ratio, m.depth);
let b1_0 = ConvBlock::load(vb.pp("b1.0"), 3, (64. * w) as usize, 3, 2, Some(1))?;
let b1_1 = ConvBlock::load(
vb.pp("b1.1"),
(64. * w) as usize,
(128. * w) as usize,
3,
2,
Some(1),
)?;
let b2_0 = C2f::load(
vb.pp("b2.0"),
(128. * w) as usize,
(128. * w) as usize,
(3. * d).round() as usize,
true,
)?;
let b2_1 = ConvBlock::load(
vb.pp("b2.1"),
(128. * w) as usize,
(256. * w) as usize,
3,
2,
Some(1),
)?;
let b2_2 = C2f::load(
vb.pp("b2.2"),
(256. * w) as usize,
(256. * w) as usize,
(6. * d).round() as usize,
true,
)?;
let b3_0 = ConvBlock::load(
vb.pp("b3.0"),
(256. * w) as usize,
(512. * w) as usize,
3,
2,
Some(1),
)?;
let b3_1 = C2f::load(
vb.pp("b3.1"),
(512. * w) as usize,
(512. * w) as usize,
(6. * d).round() as usize,
true,
)?;
let b4_0 = ConvBlock::load(
vb.pp("b4.0"),
(512. * w) as usize,
(512. * w * r) as usize,
3,
2,
Some(1),
)?;
let b4_1 = C2f::load(
vb.pp("b4.1"),
(512. * w * r) as usize,
(512. * w * r) as usize,
(3. * d).round() as usize,
true,
)?;
let b5 = Sppf::load(
vb.pp("b5.0"),
(512. * w * r) as usize,
(512. * w * r) as usize,
5,
)?;
Ok(Self {
b1_0,
b1_1,
b2_0,
b2_1,
b2_2,
b3_0,
b3_1,
b4_0,
b4_1,
b5,
})
}
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
let x2 = self
.b2_2
.forward(&self.b2_1.forward(&self.b2_0.forward(&x1)?)?)?;
let x3 = self.b3_1.forward(&self.b3_0.forward(&x2)?)?;
let x4 = self.b4_1.forward(&self.b4_0.forward(&x3)?)?;
let x5 = self.b5.forward(&x4)?;
Ok((x2, x3, x5))
}
}
#[derive(Debug)]
struct YoloV8Neck {
up: Upsample,
n1: C2f,
n2: C2f,
n3: ConvBlock,
n4: C2f,
n5: ConvBlock,
n6: C2f,
}
impl YoloV8Neck {
fn load(vb: VarBuilder, m: Multiples) -> Result<Self> {
let up = Upsample::new(2)?;
let (w, r, d) = (m.width, m.ratio, m.depth);
let n = (3. * d).round() as usize;
let n1 = C2f::load(
vb.pp("n1"),
(512. * w * (1. + r)) as usize,
(512. * w) as usize,
n,
false,
)?;
let n2 = C2f::load(
vb.pp("n2"),
(768. * w) as usize,
(256. * w) as usize,
n,
false,
)?;
let n3 = ConvBlock::load(
vb.pp("n3"),
(256. * w) as usize,
(256. * w) as usize,
3,
2,
Some(1),
)?;
let n4 = C2f::load(
vb.pp("n4"),
(768. * w) as usize,
(512. * w) as usize,
n,
false,
)?;
let n5 = ConvBlock::load(
vb.pp("n5"),
(512. * w) as usize,
(512. * w) as usize,
3,
2,
Some(1),
)?;
let n6 = C2f::load(
vb.pp("n6"),
(512. * w * (1. + r)) as usize,
(512. * w * r) as usize,
n,
false,
)?;
Ok(Self {
up,
n1,
n2,
n3,
n4,
n5,
n6,
})
}
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let x = self
.n1
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
let head_1 = self
.n2
.forward(&Tensor::cat(&[&self.up.forward(&x)?, p3], 1)?)?;
let head_2 = self
.n4
.forward(&Tensor::cat(&[&self.n3.forward(&head_1)?, &x], 1)?)?;
let head_3 = self
.n6
.forward(&Tensor::cat(&[&self.n5.forward(&head_2)?, p5], 1)?)?;
Ok((head_1, head_2, head_3))
}
}
#[derive(Debug)]
struct DetectionHead {
dfl: Dfl,
cv2: [(ConvBlock, ConvBlock, Conv2d); 3],
cv3: [(ConvBlock, ConvBlock, Conv2d); 3],
ch: usize,
no: usize,
}
fn make_anchors(
xs0: &Tensor,
xs1: &Tensor,
xs2: &Tensor,
(s0, s1, s2): (usize, usize, usize),
grid_cell_offset: f64,
) -> Result<(Tensor, Tensor)> {
let dev = xs0.device();
let mut anchor_points = vec![];
let mut stride_tensor = vec![];
for (xs, stride) in [(xs0, s0), (xs1, s1), (xs2, s2)] {
// xs is only used to extract the h and w dimensions.
let (_, _, h, w) = xs.dims4()?;
let sx = (Tensor::arange(0, w as u32, dev)?.to_dtype(DType::F32)? + grid_cell_offset)?;
let sy = (Tensor::arange(0, h as u32, dev)?.to_dtype(DType::F32)? + grid_cell_offset)?;
let sx = sx
.reshape((1, sx.elem_count()))?
.repeat((h, 1))?
.flatten_all()?;
let sy = sy
.reshape((sy.elem_count(), 1))?
.repeat((1, w))?
.flatten_all()?;
anchor_points.push(Tensor::stack(&[&sx, &sy], D::Minus1)?);
stride_tensor.push((Tensor::ones(h * w, DType::F32, dev)? * stride as f64)?);
}
let anchor_points = Tensor::cat(anchor_points.as_slice(), 0)?;
let stride_tensor = Tensor::cat(stride_tensor.as_slice(), 0)?.unsqueeze(1)?;
Ok((anchor_points, stride_tensor))
}
fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result<Tensor> {
let chunks = distance.chunk(2, 1)?;
let lt = &chunks[0];
let rb = &chunks[1];
let x1y1 = anchor_points.sub(lt)?;
let x2y2 = anchor_points.add(rb)?;
let c_xy = ((&x1y1 + &x2y2)? * 0.5)?;
let wh = (&x2y2 - &x1y1)?;
Tensor::cat(&[c_xy, wh], 1)
}
impl DetectionHead {
fn load(vb: VarBuilder, nc: usize, filters: (usize, usize, usize)) -> Result<Self> {
let ch = 16;
let dfl = Dfl::load(vb.pp("dfl"), ch)?;
let c1 = usize::max(filters.0, nc);
let c2 = usize::max(filters.0 / 4, ch * 4);
let cv3 = [
Self::load_cv3(vb.pp("cv3.0"), c1, nc, filters.0)?,
Self::load_cv3(vb.pp("cv3.1"), c1, nc, filters.1)?,
Self::load_cv3(vb.pp("cv3.2"), c1, nc, filters.2)?,
];
let cv2 = [
Self::load_cv2(vb.pp("cv2.0"), c2, ch, filters.0)?,
Self::load_cv2(vb.pp("cv2.1"), c2, ch, filters.1)?,
Self::load_cv2(vb.pp("cv2.2"), c2, ch, filters.2)?,
];
let no = nc + ch * 4;
Ok(Self {
dfl,
cv2,
cv3,
ch,
no,
})
}
fn load_cv3(
vb: VarBuilder,
c1: usize,
nc: usize,
filter: usize,
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
Ok((block0, block1, conv))
}
fn load_cv2(
vb: VarBuilder,
c2: usize,
ch: usize,
filter: usize,
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
let block0 = ConvBlock::load(vb.pp("0"), filter, c2, 3, 1, None)?;
let block1 = ConvBlock::load(vb.pp("1"), c2, c2, 3, 1, None)?;
let conv = conv2d(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?;
Ok((block0, block1, conv))
}
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
let forward_cv = |xs, i: usize| {
let xs_2 = self.cv2[i].0.forward(xs)?;
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
let xs_2 = self.cv2[i].2.forward(&xs_2)?;
let xs_3 = self.cv3[i].0.forward(xs)?;
let xs_3 = self.cv3[i].1.forward(&xs_3)?;
let xs_3 = self.cv3[i].2.forward(&xs_3)?;
Tensor::cat(&[&xs_2, &xs_3], 1)
};
let xs0 = forward_cv(xs0, 0)?;
let xs1 = forward_cv(xs1, 1)?;
let xs2 = forward_cv(xs2, 2)?;
let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
let anchors = anchors.transpose(0, 1)?;
let strides = strides.transpose(0, 1)?;
let reshape = |xs: &Tensor| {
let d = xs.dim(0)?;
let el = xs.elem_count();
xs.reshape((d, self.no, el / (d * self.no)))
};
let ys0 = reshape(&xs0)?;
let ys1 = reshape(&xs1)?;
let ys2 = reshape(&xs2)?;
let x_cat = Tensor::cat(&[ys0, ys1, ys2], 2)?;
let box_ = x_cat.i((.., ..self.ch * 4))?;
let cls = x_cat.i((.., self.ch * 4..))?;
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
let dbox = dbox.broadcast_mul(&strides)?;
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
}
}
#[derive(Debug)]
pub struct YoloV8 {
net: DarkNet,
fpn: YoloV8Neck,
head: DetectionHead,
}
impl YoloV8 {
pub fn load(vb: VarBuilder, m: Multiples, num_classes: usize) -> Result<Self> {
let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?;
Ok(Self { net, fpn, head })
}
}
impl Module for YoloV8 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
self.head.forward(&xs1, &xs2, &xs3)
}
}
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct Bbox {
pub xmin: f32,
pub ymin: f32,
pub xmax: f32,
pub ymax: f32,
pub confidence: f32,
}
// Intersection over union of two bounding boxes.
fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
let i_xmin = b1.xmin.max(b2.xmin);
let i_xmax = b1.xmax.min(b2.xmax);
let i_ymin = b1.ymin.max(b2.ymin);
let i_ymax = b1.ymax.min(b2.ymax);
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
i_area / (b1_area + b2_area - i_area)
}
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Vec<Vec<Bbox>>> {
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap();
if confidence > CONFIDENCE_THRESHOLD {
let mut class_index = 0;
for i in 0..nclasses {
if pred[4 + i] > pred[4 + class_index] {
class_index = i
}
}
if pred[class_index + 4] > 0. {
let bbox = Bbox {
xmin: pred[0] - pred[2] / 2.,
ymin: pred[1] - pred[3] / 2.,
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
};
bboxes[class_index].push(bbox)
}
}
}
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut current_index = 0;
for index in 0..bboxes_for_class.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
if iou > NMS_THRESHOLD {
drop = true;
break;
}
}
if !drop {
bboxes_for_class.swap(current_index, index);
current_index += 1;
}
}
bboxes_for_class.truncate(current_index);
}
// Annotate the original image and print boxes information.
let (initial_h, initial_w) = (img.height() as f32, img.width() as f32);
let w_ratio = initial_w / w as f32;
let h_ratio = initial_h / h as f32;
for (class_index, bboxes_for_class) in bboxes.iter_mut().enumerate() {
for b in bboxes_for_class.iter_mut() {
crate::console_log!("{}: {:?}", crate::coco_classes::NAMES[class_index], b);
b.xmin = (b.xmin * w_ratio).clamp(0., initial_w - 1.);
b.ymin = (b.ymin * h_ratio).clamp(0., initial_h - 1.);
b.xmax = (b.xmax * w_ratio).clamp(0., initial_w - 1.);
b.ymax = (b.ymax * h_ratio).clamp(0., initial_h - 1.);
}
}
Ok(bboxes)
}

View File

@ -0,0 +1,132 @@
use crate::model::{report, Bbox, Multiples, YoloV8};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
use yew_agent::{HandlerId, Public, WorkerLink};
#[wasm_bindgen]
extern "C" {
// Use `js_namespace` here to bind `console.log(..)` instead of just
// `log(..)`
#[wasm_bindgen(js_namespace = console)]
pub fn log(s: &str);
}
#[macro_export]
macro_rules! console_log {
// Note that this is using the `log` function imported above during
// `bare_bones`
($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string()))
}
// Communication to the worker happens through bincode, the model weights and configs are fetched
// on the main thread and transfered via the following structure.
#[derive(Serialize, Deserialize)]
pub struct ModelData {
pub weights: Vec<u8>,
}
struct Model {
model: YoloV8,
}
impl Model {
fn run(
&self,
_link: &WorkerLink<Worker>,
_id: HandlerId,
image_data: Vec<u8>,
) -> Result<Vec<Vec<Bbox>>> {
console_log!("image data: {}", image_data.len());
let image_data = std::io::Cursor::new(image_data);
let original_image = image::io::Reader::new(image_data)
.with_guessed_format()?
.decode()
.map_err(candle::Error::wrap)?;
let image = {
let data = original_image
.resize_exact(640, 640, image::imageops::FilterType::Triangle)
.to_rgb8()
.into_raw();
Tensor::from_vec(data, (640, 640, 3), &Device::Cpu)?.permute((2, 0, 1))?
};
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = self.model.forward(&image)?.squeeze(0)?;
console_log!("generated predictions {predictions:?}");
let bboxes = report(&predictions, original_image, 640, 640)?;
Ok(bboxes)
}
}
impl Model {
fn load(md: ModelData) -> Result<Self> {
let dev = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
let model = YoloV8::load(vb, Multiples::s(), 80)?;
Ok(Self { model })
}
}
pub struct Worker {
link: WorkerLink<Self>,
model: Option<Model>,
}
#[derive(Serialize, Deserialize)]
pub enum WorkerInput {
ModelData(ModelData),
Run(Vec<u8>),
}
#[derive(Serialize, Deserialize)]
pub enum WorkerOutput {
ProcessingDone(std::result::Result<Vec<Vec<Bbox>>, String>),
WeightsLoaded,
}
impl yew_agent::Worker for Worker {
type Input = WorkerInput;
type Message = ();
type Output = std::result::Result<WorkerOutput, String>;
type Reach = Public<Self>;
fn create(link: WorkerLink<Self>) -> Self {
Self { link, model: None }
}
fn update(&mut self, _msg: Self::Message) {
// no messaging
}
fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
let output = match msg {
WorkerInput::ModelData(md) => match Model::load(md) {
Ok(model) => {
self.model = Some(model);
Ok(WorkerOutput::WeightsLoaded)
}
Err(err) => Err(format!("model creation error {err:?}")),
},
WorkerInput::Run(image_data) => match &mut self.model {
None => Err("model has not been set yet".to_string()),
Some(model) => {
let result = model
.run(&self.link, id, image_data)
.map_err(|e| e.to_string());
Ok(WorkerOutput::ProcessingDone(result))
}
},
};
self.link.respond(id, output);
}
fn name_of_resource() -> &'static str {
"worker.js"
}
fn resource_path_is_relative() -> bool {
true
}
}