Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 02:16:37 +00:00)
[WIP] Improve Yolo WASM UI example (#591)
* return detections with class names
* ignore .DS_Store
* example of how to load the wasm module
* add param to set model size
* add param for model size
* accept iou and confidence threshold on run
* conf and iou thresholds
* clamp only
* remove images from branch
* a couple of renamings, add readme with instructions
* final design
* minor font + border update
.gitignore (vendored) | 2
@@ -31,3 +31,5 @@ candle-wasm-examples/*/*.jpeg
 candle-wasm-examples/*/*.wav
 candle-wasm-examples/*/*.safetensors
 candle-wasm-examples/*/package-lock.json
+
+.DS_Store
candle-wasm-examples/yolo/README.md (new file) | 44
@@ -0,0 +1,44 @@
## Running Yolo Examples

Here, we provide two examples of how to run YOLOv8 using a Candle-compiled WASM binary and two different runtimes.

### Pure Rust UI

To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install).
From the `candle-wasm-examples/yolo` directory, run:

Download assets:

```bash
wget -c https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg
wget -c https://huggingface.co/lmz/candle-yolo-v8/resolve/main/yolov8s.safetensors
```

Run the hot-reload server:

```bash
trunk serve --release --public-url / --port 8080
```

### Vanilla JS and WebWorkers

To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:

```bash
sh build-lib.sh
```

This will bundle the library under `./build`, and we can import it inside our WebWorker like a normal JS module:

```js
import init, { Model } from "./build/m.js";
```
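
For a rough sense of how these bindings fit together, a stripped-down worker could look like the sketch below. The message fields (`weightsURL`, `modelSize`, `imageURL` and the thresholds) are only illustrative names for this sketch; the complete worker, with model caching and status updates, ships alongside this example as `yoloWorker.js`.

```js
// minimal-worker.js -- illustrative sketch only; see yoloWorker.js for the full version.
import init, { Model } from "./build/m.js";

self.addEventListener("message", async (event) => {
  // Field names below are assumptions made for this sketch.
  const { weightsURL, modelSize, imageURL, confidence, iou_threshold } = event.data;

  await init(); // initialize the WASM module

  // Build the model from raw safetensors bytes; modelSize is "n", "s", "m", "l" or "x".
  const weights = new Uint8Array(await (await fetch(weightsURL)).arrayBuffer());
  const model = new Model(weights, modelSize);

  // `run` takes the raw image bytes plus the two thresholds and returns a JSON
  // string of [class_name, { xmin, ymin, xmax, ymax, confidence }] pairs.
  const image = new Uint8Array(await (await fetch(imageURL)).arrayBuffer());
  const detections = JSON.parse(model.run(image, confidence, iou_threshold));

  self.postMessage({ status: "complete", output: detections });
});
```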

The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so there is no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:

```bash
python -m http.server
```

Then open `http://localhost:8000/lib-example.html` in your browser.
@@ -4,7 +4,7 @@
     <meta charset="utf-8" />
     <title>Welcome to Candle!</title>

-    <link data-trunk rel="copy-file" href="yolo.safetensors" />
+    <link data-trunk rel="copy-file" href="yolov8s.safetensors" />
     <link data-trunk rel="copy-file" href="bike.jpeg" />
     <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
     <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
candle-wasm-examples/yolo/lib-example.html (new file) | 414
@@ -0,0 +1,414 @@
<html>
  <head>
    <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
    <title>Candle YOLOv8 Rust/WASM</title>
  </head>
  <body></body>
</html>

<!doctype html>
<html>
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <style>
      @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
      html,
      body {
        font-family: "Source Sans 3", sans-serif;
      }
      code,
      output,
      select,
      pre {
        font-family: "Source Code Pro", monospace;
      }
    </style>
    <script src="https://cdn.tailwindcss.com"></script>
    <script
      src="https://cdn.jsdelivr.net/gh/huggingface/hub-js-utils/share-canvas.js"
      type="module"
    ></script>
    <script type="module">
      const MODEL_BASEURL =
        "https://huggingface.co/lmz/candle-yolo-v8/resolve/main/";

      const MODELS = {
        yolov8n: {
          model_size: "n",
          url: "yolov8n.safetensors",
        },
        yolov8s: {
          model_size: "s",
          url: "yolov8s.safetensors",
        },
        yolov8m: {
          model_size: "m",
          url: "yolov8m.safetensors",
        },
        yolov8l: {
          model_size: "l",
          url: "yolov8l.safetensors",
        },
        yolov8x: {
          model_size: "x",
          url: "yolov8x.safetensors",
        },
      };

      // init web worker
      const yoloWorker = new Worker("./yoloWorker.js", { type: "module" });

      let hasImage = false;
      //add event listener to image examples
      document.querySelector("#image-select").addEventListener("click", (e) => {
        const target = e.target;
        if (target.nodeName === "IMG") {
          const href = target.src;
          drawImageCanvas(href);
        }
      });
      //add event listener to file input
      document.querySelector("#file-upload").addEventListener("change", (e) => {
        const target = e.target;
        if (target.files.length > 0) {
          const href = URL.createObjectURL(target.files[0]);
          drawImageCanvas(href);
        }
      });
      // add event listener to drop-area
      const dropArea = document.querySelector("#drop-area");
      dropArea.addEventListener("dragenter", (e) => {
        e.preventDefault();
        dropArea.classList.add("border-blue-700");
      });
      dropArea.addEventListener("dragleave", (e) => {
        e.preventDefault();
        dropArea.classList.remove("border-blue-700");
      });
      dropArea.addEventListener("dragover", (e) => {
        e.preventDefault();
      });
      dropArea.addEventListener("drop", (e) => {
        e.preventDefault();
        dropArea.classList.remove("border-blue-700");
        const url = e.dataTransfer.getData("text/uri-list");
        const files = e.dataTransfer.files;

        if (files.length > 0) {
          const href = URL.createObjectURL(files[0]);
          drawImageCanvas(href);
        } else if (url) {
          drawImageCanvas(url);
        }
      });

      function drawImageCanvas(imgURL) {
        const canvas = document.querySelector("#canvas");
        const canvasResult = document.querySelector("#canvas-result");
        canvasResult
          .getContext("2d")
          .clearRect(0, 0, canvas.width, canvas.height);
        const ctx = canvas.getContext("2d");
        ctx.clearRect(0, 0, canvas.width, canvas.height);
        document.querySelector("#share-btn").hidden = true;

        const img = new Image();
        img.crossOrigin = "anonymous";

        img.onload = () => {
          canvas.width = img.width;
          canvas.height = img.height;
          ctx.drawImage(img, 0, 0);

          canvas.parentElement.style.height = canvas.offsetHeight + "px";
          hasImage = true;
          document.querySelector("#detect").disabled = false;
        };
        img.src = imgURL;
      }

      async function classifyImage(
        imageURL, // URL of image to classify
        modelID, // ID of model to use
        modelURL, // URL to model file
        modelSize, // size of model
        confidence, // confidence threshold
        iou_threshold, // IoU threshold
        updateStatus // function receives status updates
      ) {
        return new Promise((resolve, reject) => {
          yoloWorker.postMessage({
            imageURL,
            modelID,
            modelURL,
            modelSize,
            confidence,
            iou_threshold,
          });
          yoloWorker.addEventListener("message", (event) => {
            if ("status" in event.data) {
              updateStatus(event.data.status);
            }
            if ("error" in event.data) {
              reject(new Error(event.data.error));
            }
            if (event.data.status === "complete") {
              resolve(event.data);
            }
          });
        });
      }
      // add event listener to detect button
      document.querySelector("#detect").addEventListener("click", async () => {
        if (!hasImage) {
          return;
        }
        const modelID = document.querySelector("#model").value;
        const modelURL = MODEL_BASEURL + MODELS[modelID].url;
        const modelSize = MODELS[modelID].model_size;
        const confidence = parseFloat(
          document.querySelector("#confidence").value
        );
        const iou_threshold = parseFloat(
          document.querySelector("#iou_threshold").value
        );

        const canvasInput = document.querySelector("#canvas");
        const canvas = document.querySelector("#canvas-result");
        canvas.width = canvasInput.width;
        canvas.height = canvasInput.height;

        const scale = canvas.width / canvas.offsetWidth;

        const ctx = canvas.getContext("2d");
        ctx.drawImage(canvasInput, 0, 0);
        const imageURL = canvas.toDataURL();

        const results = await classifyImage(
          imageURL,
          modelID,
          modelURL,
          modelSize,
          confidence,
          iou_threshold,
          updateStatus
        );

        const { output } = results;

        ctx.lineWidth = 1 + 2 * scale;
        ctx.strokeStyle = "#3c8566";
        ctx.fillStyle = "#0dff9a";
        const fontSize = 14 * scale;
        ctx.font = `${fontSize}px sans-serif`;
        for (const [label, bbox] of output) {
          const [x, y, w, h] = [
            bbox.xmin,
            bbox.ymin,
            bbox.xmax - bbox.xmin,
            bbox.ymax - bbox.ymin,
          ];

          const confidence = bbox.confidence;

          const text = `${label} ${confidence.toFixed(2)}`;
          const width = ctx.measureText(text).width;
          ctx.fillStyle = "#3c8566";
          ctx.fillRect(x - 2, y - fontSize, width + 4, fontSize);
          ctx.fillStyle = "#e3fff3";

          ctx.strokeRect(x, y, w, h);
          ctx.fillText(text, x, y - 2);
        }
      });

      function updateStatus(statusMessage) {
        const button = document.querySelector("#detect");
        if (statusMessage === "detecting") {
          button.disabled = true;
          button.classList.add("bg-blue-700");
          button.classList.remove("bg-blue-950");
          button.textContent = "Detecting...";
        } else if (statusMessage === "complete") {
          button.disabled = false;
          button.classList.add("bg-blue-950");
          button.classList.remove("bg-blue-700");
          button.textContent = "Detect Objects";
          document.querySelector("#share-btn").hidden = false;
        }
      }
      document.querySelector("#share-btn").addEventListener("click", () => {
        shareToCommunity(
          "lmz/candle-yolo",
          "Candle + YOLOv8",
          "YOLOv8 with [Candle](https://github.com/huggingface/candle)",
          "canvas-result",
          "share-btn"
        );
      });
    </script>
  </head>
  <body class="container max-w-4xl mx-auto p-4">
    <main class="grid grid-cols-1 gap-8">
      <div>
        <h1 class="text-5xl font-bold">Candle YOLOv8</h1>
        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
        <p class="max-w-lg">
          Running an object detection model in the browser using rust/wasm with
          an image. This demo uses the
          <a
            href="https://huggingface.co/lmz/candle-yolo-v8"
            target="_blank"
            class="underline hover:text-blue-500 hover:no-underline"
          >
            Candle YOLOv8
          </a>
          models to detect objects in images and a WASM runtime built with
          <a
            href="https://github.com/huggingface/candle/"
            target="_blank"
            class="underline hover:text-blue-500 hover:no-underline"
            >Candle
          </a>
        </p>
      </div>

      <div>
        <label for="model" class="font-medium">Model Options: </label>
        <select
          id="model"
          class="border-2 border-gray-500 rounded-md font-light"
        >
          <option value="yolov8n" selected>yolov8n (6.37 MB)</option>
          <option value="yolov8s">yolov8s (22.4 MB)</option>
          <option value="yolov8m">yolov8m (51.9 MB)</option>
          <option value="yolov8l">yolov8l (87.5 MB)</option>
          <option value="yolov8x">yolov8x (137 MB)</option>
        </select>
      </div>
      <!-- drag and drop area -->
      <div class="relative">
        <div
          id="drop-area"
          class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative aspect-video w-full overflow-hidden"
        >
          <div
            class="flex flex-col items-center justify-center space-y-1 text-center"
          >
            <svg
              width="25"
              height="25"
              viewBox="0 0 25 25"
              fill="none"
              xmlns="http://www.w3.org/2000/svg"
            >
              <path
                d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
                fill="#000"
              />
            </svg>
            <div class="flex text-sm text-gray-600">
              <label
                for="file-upload"
                class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
              >
                <span>Drag and drop your image here</span>
                <span class="block text-xs">or</span>
                <span class="block text-xs">Click to upload</span>
              </label>
            </div>
            <input
              id="file-upload"
              name="file-upload"
              type="file"
              class="sr-only"
            />
          </div>
          <canvas
            id="canvas"
            class="absolute pointer-events-none w-full"
          ></canvas>
          <canvas
            id="canvas-result"
            class="absolute pointer-events-none w-full"
          ></canvas>
        </div>
        <div class="text-right py-2">
          <button
            id="share-btn"
            hidden
            class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50"
          >
            <img
              src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg"
            />
          </button>
        </div>
      </div>
      <div>
        <div class="flex gap-3 items-center" id="image-select">
          <h3 class="font-medium">Examples:</h3>

          <img
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
            class="cursor-pointer w-24 h-24 object-cover"
          />
          <img
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
            class="cursor-pointer w-24 h-24 object-cover"
          />
        </div>
      </div>
      <div>
        <div class="grid grid-cols-3 max-w-md items-center gap-3">
          <label class="text-sm font-medium" for="confidence"
            >Confidence Threshold</label
          >
          <input
            type="range"
            id="confidence"
            name="confidence"
            min="0"
            max="1"
            step="0.01"
            value="0.25"
            oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
          />
          <output
            class="text-xs font-light px-1 py-1 border border-gray-700 rounded-md w-min"
            >0.25</output
          >

          <label class="text-sm font-medium" for="iou_threshold"
            >IoU Threshold</label
          >

          <input
            type="range"
            id="iou_threshold"
            name="iou_threshold"
            min="0"
            max="1"
            step="0.01"
            value="0.45"
            oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
          />
          <output
            class="font-extralight text-xs px-1 py-1 border border-gray-700 rounded-md w-min"
            >0.45</output
          >
        </div>
      </div>
      <div>
        <button
          id="detect"
          disabled
          class="bg-blue-950 hover:bg-blue-700 text-white font-normal py-2 px-4 rounded disabled:opacity-75 disabled:hover:bg-blue-950"
        >
          Detect Objects
        </button>
      </div>
    </main>
  </body>
</html>
@@ -1,5 +1,5 @@
 use crate::console_log;
-use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
+use crate::worker::{ModelData, RunData, Worker, WorkerInput, WorkerOutput};
 use wasm_bindgen::prelude::*;
 use wasm_bindgen_futures::JsFuture;
 use yew::{html, Component, Context, Html};
@@ -50,9 +50,13 @@ pub struct App {
 }

 async fn model_data_load() -> Result<ModelData, JsValue> {
-    let weights = fetch_url("yolo.safetensors").await?;
+    let weights = fetch_url("yolov8s.safetensors").await?;
+    let model_size = "s".to_string();
     console_log!("loaded weights {}", weights.len());
-    Ok(ModelData { weights })
+    Ok(ModelData {
+        weights,
+        model_size,
+    })
 }

 fn performance_now() -> Option<f64> {
@@ -162,7 +166,11 @@ impl Component for App {
                         let status = format!("{err:?}");
                         Msg::UpdateStatus(status)
                     }
-                    Ok(image_data) => Msg::WorkerInMsg(WorkerInput::Run(image_data)),
+                    Ok(image_data) => Msg::WorkerInMsg(WorkerInput::RunData(RunData {
+                        image_data,
+                        conf_threshold: 0.5,
+                        iou_threshold: 0.5,
+                    })),
                 }
             });
         }
@@ -1,3 +1,5 @@
+use candle_wasm_example_yolo::coco_classes;
+use candle_wasm_example_yolo::model::Bbox;
 use candle_wasm_example_yolo::worker::Model as M;
 use wasm_bindgen::prelude::*;

@@ -9,15 +11,36 @@ pub struct Model {
 #[wasm_bindgen]
 impl Model {
     #[wasm_bindgen(constructor)]
-    pub fn new(data: Vec<u8>) -> Result<Model, JsError> {
-        let inner = M::load_(&data)?;
+    pub fn new(data: Vec<u8>, model_size: &str) -> Result<Model, JsError> {
+        let inner = M::load_(&data, model_size)?;
         Ok(Self { inner })
     }

     #[wasm_bindgen]
-    pub fn run(&self, image: Vec<u8>) -> Result<String, JsError> {
-        let boxes = self.inner.run(image)?;
-        let json = serde_json::to_string(&boxes)?;
+    pub fn run(
+        &self,
+        image: Vec<u8>,
+        conf_threshold: f32,
+        iou_threshold: f32,
+    ) -> Result<String, JsError> {
+        let bboxes = self.inner.run(image, conf_threshold, iou_threshold)?;
+        let mut detections: Vec<(String, Bbox)> = vec![];
+
+        for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
+            for b in bboxes_for_class.iter() {
+                detections.push((
+                    coco_classes::NAMES[class_index].to_string(),
+                    Bbox {
+                        xmin: b.xmin,
+                        ymin: b.ymin,
+                        xmax: b.xmax,
+                        ymax: b.ymax,
+                        confidence: b.confidence,
+                    },
+                ));
+            }
+        }
+        let json = serde_json::to_string(&detections)?;
         Ok(json)
     }
 }
@@ -1,6 +1,6 @@
 mod app;
-mod coco_classes;
-mod model;
+pub mod coco_classes;
+pub mod model;
 pub mod worker;
 pub use app::App;
 pub use worker::Worker;
@@ -623,16 +623,25 @@ fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
     i_area / (b1_area + b2_area - i_area)
 }

-pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Vec<Vec<Bbox>>> {
+pub fn report(
+    pred: &Tensor,
+    img: DynamicImage,
+    w: usize,
+    h: usize,
+    conf_threshold: f32,
+    iou_threshold: f32,
+) -> Result<Vec<Vec<Bbox>>> {
     let (pred_size, npreds) = pred.dims2()?;
     let nclasses = pred_size - 4;
+    let conf_threshold = conf_threshold.clamp(0.0, 1.0);
+    let iou_threshold = iou_threshold.clamp(0.0, 1.0);
     // The bounding boxes grouped by (maximum) class index.
     let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
     // Extract the bounding boxes for which confidence is above the threshold.
     for index in 0..npreds {
         let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
         let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap();
-        if confidence > CONFIDENCE_THRESHOLD {
+        if confidence > conf_threshold {
             let mut class_index = 0;
             for i in 0..nclasses {
                 if pred[4 + i] > pred[4 + class_index] {
@@ -659,7 +668,7 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Ve
             let mut drop = false;
             for prev_index in 0..current_index {
                 let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
-                if iou > NMS_THRESHOLD {
+                if iou > iou_threshold {
                     drop = true;
                     break;
                 }
@@ -25,6 +25,14 @@ macro_rules! console_log {
 #[derive(Serialize, Deserialize)]
 pub struct ModelData {
     pub weights: Vec<u8>,
+    pub model_size: String,
 }

+#[derive(Serialize, Deserialize)]
+pub struct RunData {
+    pub image_data: Vec<u8>,
+    pub conf_threshold: f32,
+    pub iou_threshold: f32,
+}
+
 pub struct Model {
@@ -32,7 +40,12 @@ pub struct Model {
 }

 impl Model {
-    pub fn run(&self, image_data: Vec<u8>) -> Result<Vec<Vec<Bbox>>> {
+    pub fn run(
+        &self,
+        image_data: Vec<u8>,
+        conf_threshold: f32,
+        iou_threshold: f32,
+    ) -> Result<Vec<Vec<Bbox>>> {
         console_log!("image data: {}", image_data.len());
         let image_data = std::io::Cursor::new(image_data);
         let original_image = image::io::Reader::new(image_data)
@@ -68,20 +81,37 @@ impl Model {
         let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
         let predictions = self.model.forward(&image_t)?.squeeze(0)?;
         console_log!("generated predictions {predictions:?}");
-        let bboxes = report(&predictions, original_image, width, height)?;
+        let bboxes = report(
+            &predictions,
+            original_image,
+            width,
+            height,
+            conf_threshold,
+            iou_threshold,
+        )?;
         Ok(bboxes)
     }

-    pub fn load_(weights: &[u8]) -> Result<Self> {
+    pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
+        let multiples = match model_size {
+            "n" => Multiples::n(),
+            "s" => Multiples::s(),
+            "m" => Multiples::m(),
+            "l" => Multiples::l(),
+            "x" => Multiples::x(),
+            _ => Err(candle::Error::Msg(
+                "invalid model size: must be n, s, m, l or x".to_string(),
+            ))?,
+        };
         let dev = &Device::Cpu;
         let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
         let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
-        let model = YoloV8::load(vb, Multiples::s(), 80)?;
+        let model = YoloV8::load(vb, multiples, 80)?;
         Ok(Self { model })
     }

     pub fn load(md: ModelData) -> Result<Self> {
-        Self::load_(&md.weights)
+        Self::load_(&md.weights, &md.model_size.to_string())
     }
 }

@@ -93,7 +123,7 @@ pub struct Worker {
 #[derive(Serialize, Deserialize)]
 pub enum WorkerInput {
     ModelData(ModelData),
-    Run(Vec<u8>),
+    RunData(RunData),
 }

 #[derive(Serialize, Deserialize)]
@@ -125,10 +155,12 @@ impl yew_agent::Worker for Worker {
                 }
                 Err(err) => Err(format!("model creation error {err:?}")),
             },
-            WorkerInput::Run(image_data) => match &mut self.model {
+            WorkerInput::RunData(rd) => match &mut self.model {
                 None => Err("model has not been set yet".to_string()),
                 Some(model) => {
-                    let result = model.run(image_data).map_err(|e| e.to_string());
+                    let result = model
+                        .run(rd.image_data, rd.conf_threshold, rd.iou_threshold)
+                        .map_err(|e| e.to_string());
                     Ok(WorkerOutput::ProcessingDone(result))
                 }
             },
candle-wasm-examples/yolo/yoloWorker.js (new file) | 49
@@ -0,0 +1,49 @@
//load the candle yolo wasm module
import init, { Model } from "./build/m.js";

class Yolo {
  static instance = {};
  // Retrieve the YOLO model. When called for the first time,
  // this will load the model and save it for future use.
  static async getInstance(modelID, modelURL, modelSize) {
    // load individual modelID only once
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({ status: `loading model ${modelID}:${modelSize}` });
      const modelRes = await fetch(modelURL);
      const yoloArrayBuffer = await modelRes.arrayBuffer();
      const weightsArrayU8 = new Uint8Array(yoloArrayBuffer);
      this.instance[modelID] = new Model(weightsArrayU8, modelSize);
    } else {
      self.postMessage({ status: "model already loaded" });
    }
    return this.instance[modelID];
  }
}

self.addEventListener("message", async (event) => {
  const { imageURL, modelID, modelURL, modelSize, confidence, iou_threshold } =
    event.data;
  try {
    self.postMessage({ status: "detecting" });

    const yolo = await Yolo.getInstance(modelID, modelURL, modelSize);

    self.postMessage({ status: "loading image" });
    const imgRes = await fetch(imageURL);
    const imgData = await imgRes.arrayBuffer();
    const imageArrayU8 = new Uint8Array(imgData);

    self.postMessage({ status: `running inference ${modelID}:${modelSize}` });
    const bboxes = yolo.run(imageArrayU8, confidence, iou_threshold);

    // Send the output back to the main thread as JSON
    self.postMessage({
      status: "complete",
      output: JSON.parse(bboxes),
    });
  } catch (e) {
    self.postMessage({ error: e });
  }
});
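
Taken together, the worker above defines a small message contract. A condensed sketch of how a page might drive it is shown below; the full wiring, including status display and canvas drawing, is in `lib-example.html` earlier in this commit, and the asset URLs here are only placeholders.

```js
// Spawn the worker as an ES module and listen for its replies.
const worker = new Worker("./yoloWorker.js", { type: "module" });

worker.addEventListener("message", ({ data }) => {
  if ("status" in data) console.log(data.status); // progress updates
  if ("error" in data) console.error(data.error); // failures
  if (data.status === "complete") console.log(data.output); // [label, bbox] pairs
});

// Kick off a detection run; the thresholds mirror the UI defaults.
worker.postMessage({
  imageURL: "bike.jpeg", // placeholder: any fetchable image URL
  modelID: "yolov8s",
  modelURL: "https://huggingface.co/lmz/candle-yolo-v8/resolve/main/yolov8s.safetensors",
  modelSize: "s",
  confidence: 0.25,
  iou_threshold: 0.45,
});
```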