diff --git a/candle-wasm-examples/yolo/README.md b/candle-wasm-examples/yolo/README.md
index 37c4c99b..62f35dfb 100644
--- a/candle-wasm-examples/yolo/README.md
+++ b/candle-wasm-examples/yolo/README.md
@@ -31,7 +31,7 @@ sh build-lib.sh
 This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
 ```js
-import init, { Model } from "./build/m.js";
+import init, { Model, ModelPose } from "./build/m.js";
 ```
 The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
diff --git a/candle-wasm-examples/yolo/lib-example.html b/candle-wasm-examples/yolo/lib-example.html
index 23c94d38..bab2ec13 100644
--- a/candle-wasm-examples/yolo/lib-example.html
+++ b/candle-wasm-examples/yolo/lib-example.html
@@ -54,8 +54,50 @@
         model_size: "x",
         url: "yolov8x.safetensors",
       },
+      yolov8n_pose: {
+        model_size: "n",
+        url: "yolov8n-pose.safetensors",
+      },
+      yolov8s_pose: {
+        model_size: "s",
+        url: "yolov8s-pose.safetensors",
+      },
+      yolov8m_pose: {
+        model_size: "m",
+        url: "yolov8m-pose.safetensors",
+      },
+      yolov8l_pose: {
+        model_size: "l",
+        url: "yolov8l-pose.safetensors",
+      },
+      yolov8x_pose: {
+        model_size: "x",
+        url: "yolov8x-pose.safetensors",
+      },
     };
+    // keypoint-index pairs to connect when drawing the COCO person skeleton
+    const COCO_PERSON_SKELETON = [
+      [4, 0], // head
+      [3, 0],
+      [16, 14], // left lower leg
+      [14, 12], // left upper leg
+      [6, 12], // left torso
+      [6, 5], // top torso
+      [6, 8], // upper arm
+      [8, 10], // lower arm
+      [1, 2], // head
+      [1, 3], // right head
+      [2, 4], // left head
+      [3, 5], // right neck
+      [4, 6], // left neck
+      [5, 7], // right upper arm
+      [7, 9], // right lower arm
+      [5, 11], // right torso
+      [11, 12], // bottom torso
+      [11, 13], // right upper leg
+      [13, 15], // right lower leg
+    ];
+
     // init web worker
     const yoloWorker = new Worker("./yoloWorker.js", { type: "module" });
@@ -202,17 +244,28 @@
       ctx.fillStyle = "#0dff9a";
       const fontSize = 14 * scale;
       ctx.font = `${fontSize}px sans-serif`;
-      for (const [label, bbox] of output) {
-        const [x, y, w, h] = [
-          bbox.xmin,
-          bbox.ymin,
-          bbox.xmax - bbox.xmin,
-          bbox.ymax - bbox.ymin,
-        ];
+      for (const detection of output) {
+        // a `keypoints` field marks pose model output; otherwise it is a [label, bbox] pair
+        let xmin, xmax, ymin, ymax, label, confidence, keypoints;
+        if ("keypoints" in detection) {
+          xmin = detection.xmin;
+          xmax = detection.xmax;
+          ymin = detection.ymin;
+          ymax = detection.ymax;
+          confidence = detection.confidence;
+          keypoints = detection.keypoints;
+        } else {
+          const [_label, bbox] = detection;
+          label = _label;
+          xmin = bbox.xmin;
+          xmax = bbox.xmax;
+          ymin = bbox.ymin;
+          ymax = bbox.ymax;
+          confidence = bbox.confidence;
+        }
+        const [x, y, w, h] = [xmin, ymin, xmax - xmin, ymax - ymin];
-        const confidence = bbox.confidence;
-
-        const text = `${label} ${confidence.toFixed(2)}`;
+        const text = `${label ? label + " " : ""}${confidence.toFixed(2)}`;
         const width = ctx.measureText(text).width;
         ctx.fillStyle = "#3c8566";
         ctx.fillRect(x - 2, y - fontSize, width + 4, fontSize);
@@ -220,6 +273,28 @@
         ctx.strokeRect(x, y, w, h);
         ctx.fillText(text, x, y - 2);
+        if (keypoints) {
+          ctx.save();
+          ctx.fillStyle = "magenta";
+          ctx.strokeStyle = "yellow";
+
+          for (const keypoint of keypoints) {
+            const { x, y } = keypoint;
+            ctx.beginPath();
+            ctx.arc(x, y, 3, 0, 2 * Math.PI);
+            ctx.fill();
+          }
+          ctx.beginPath();
+          for (const [xid, yid] of COCO_PERSON_SKELETON) {
+            // draw a line between connected skeleton keypoints
+            if (keypoints[xid] && keypoints[yid]) {
+              ctx.moveTo(keypoints[xid].x, keypoints[xid].y);
+              ctx.lineTo(keypoints[yid].x, keypoints[yid].y);
+            }
+          }
+          ctx.stroke();
+          ctx.restore();
+        }
       }
     });
@@ -229,12 +304,12 @@
     button.disabled = true;
     button.classList.add("bg-blue-700");
     button.classList.remove("bg-blue-950");
-    button.textContent = "Detecting...";
+    button.textContent = "Predicting...";
   } else if (statusMessage === "complete") {
     button.disabled = false;
     button.classList.add("bg-blue-950");
     button.classList.remove("bg-blue-700");
-    button.textContent = "Detect Objects";
+    button.textContent = "Predict";
     document.querySelector("#share-btn").hidden = false;
   }
 }
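The drawing loop above distinguishes the two model outputs solely by the presence of a `keypoints` field. For reference, a minimal sketch of both shapes, inferred only from the fields the loop actually reads; the values are illustrative, and any field not read above is an assumption:

```js
// Object detection output: an array of [label, bbox] pairs.
const detectOutput = [
  ["person", { xmin: 12, ymin: 34, xmax: 156, ymax: 278, confidence: 0.91 }],
];

// Pose output: objects whose `keypoints` array has 17 COCO-ordered entries,
// so the COCO_PERSON_SKELETON index pairs (0..16) resolve to {x, y} points.
const poseOutput = [
  {
    xmin: 12,
    ymin: 34,
    xmax: 156,
    ymax: 278,
    confidence: 0.88,
    keypoints: [
      { x: 60, y: 50 }, // index 0; the skeleton pairs above index into this array
      // ... 16 more { x, y } entries ...
    ],
  },
];
```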
@@ -250,27 +325,31 @@
+        🕯️
         Candle YOLOv8
         Rust/WASM Demo
-          Running an object detection model in the browser using rust/wasm with
-          an image. This demo uses the
+          This demo showcases object detection and pose estimation models in
+          your browser using Rust/WASM. It utilizes
-          Candle YOLOv8
+          Safetensors YOLOv8 models
-          models to detect objects in images and WASM runtime built with
+          and a WASM runtime built with
           Candle
+          .
+
+          To run pose estimation, select a YOLOv8 pose model from the dropdown.
@@ -285,6 +364,12 @@
+            [<option> entries for yolov8n_pose, yolov8s_pose, yolov8m_pose, yolov8l_pose, and yolov8x_pose added to the model dropdown; markup not recoverable]
@@ -358,6 +443,10 @@
           src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
           class="cursor-pointer w-24 h-24 object-cover"
         />
+          [an additional example <img> added after bike.jpeg; URL not recoverable]
@@ -406,7 +495,7 @@
           disabled
           class="bg-blue-950 hover:bg-blue-700 text-white font-normal py-2 px-4 rounded disabled:opacity-75 disabled:hover:bg-blue-950"
         >
-          Detect Objects
+          Predict
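The lib-example.html changes end here; the worker diff below does the actual model routing. As orientation, a hypothetical usage sketch of the page/worker handshake, with message field names mirroring the model table (`modelID`, `modelURL`, `modelSize`); the exact message protocol is not shown in this diff, so treat every field here as an assumption:

```js
const yoloWorker = new Worker("./yoloWorker.js", { type: "module" });

yoloWorker.addEventListener("message", (event) => {
  if (event.data.status === "complete") {
    // [label, bbox] pairs for detection models, keypoint objects for pose models
    console.log(event.data.output);
  }
});

// Hypothetical message shape: modelID doubles as the pose/detect switch,
// since the worker routes any ID matching /pose/ to ModelPose.
yoloWorker.postMessage({
  modelID: "yolov8n_pose",
  modelURL: "yolov8n-pose.safetensors", // resolved against the model table's base URL in the real page
  modelSize: "n",
  imageURL: "bike.jpeg", // hypothetical; how the image reaches the worker is not part of this diff
});
```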
diff --git a/candle-wasm-examples/yolo/yoloWorker.js b/candle-wasm-examples/yolo/yoloWorker.js
index 0e45ed6f..93097372 100644
--- a/candle-wasm-examples/yolo/yoloWorker.js
+++ b/candle-wasm-examples/yolo/yoloWorker.js
@@ -1,5 +1,5 @@
 //load the candle yolo wasm module
-import init, { Model } from "./build/m.js";
+import init, { Model, ModelPose } from "./build/m.js";
 class Yolo {
   static instance = {};
@@ -14,7 +14,12 @@ class Yolo {
       const modelRes = await fetch(modelURL);
       const yoloArrayBuffer = await modelRes.arrayBuffer();
       const weightsArrayU8 = new Uint8Array(yoloArrayBuffer);
-      this.instance[modelID] = new Model(weightsArrayU8, modelSize);
+      if (/pose/.test(modelID)) {
+        // if pose model, use ModelPose
+        this.instance[modelID] = new ModelPose(weightsArrayU8, modelSize);
+      } else {
+        this.instance[modelID] = new Model(weightsArrayU8, modelSize);
+      }
     } else {
       self.postMessage({ status: "model already loaded" });
     }
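Worth noting in the worker diff: instances are cached per `modelID`, so each model's weights are fetched and the wasm class constructed only once, and switching models in the dropdown pays the download cost a single time. A condensed sketch of that pattern, assuming the surrounding method is a lazy `getInstance` and that `init()` is awaited before construction (both assumptions; the diff itself shows only the cache, the fetch, and the `/pose/` dispatch):

```js
import init, { Model, ModelPose } from "./build/m.js";

class Yolo {
  static instance = {}; // one constructed wasm model per modelID

  // Lazily fetch weights and construct the right wasm class, once per modelID.
  static async getInstance(modelID, modelURL, modelSize) {
    if (!this.instance[modelID]) {
      await init(); // initialize the wasm module before constructing models
      const weightsArrayU8 = new Uint8Array(
        await (await fetch(modelURL)).arrayBuffer()
      );
      // route pose model IDs to ModelPose, everything else to Model
      const ModelClass = /pose/.test(modelID) ? ModelPose : Model;
      this.instance[modelID] = new ModelClass(weightsArrayU8, modelSize);
    }
    return this.instance[modelID];
  }
}
```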