diff --git a/.gitignore b/.gitignore index b313c3ec..85dc61c0 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,5 @@ candle-wasm-examples/*/*.jpeg candle-wasm-examples/*/*.wav candle-wasm-examples/*/*.safetensors candle-wasm-examples/*/package-lock.json + +.DS_Store \ No newline at end of file diff --git a/candle-wasm-examples/yolo/README.md b/candle-wasm-examples/yolo/README.md new file mode 100644 index 00000000..37c4c99b --- /dev/null +++ b/candle-wasm-examples/yolo/README.md @@ -0,0 +1,44 @@ +## Running YOLO Examples + +Here, we provide two examples of how to run YOLOv8 using a Candle-compiled WASM binary and runtimes. + +### Pure Rust UI + +To build and test the UI made in Rust, you will need [Trunk](https://trunkrs.dev/#install). +From the `candle-wasm-examples/yolo` directory, run: + +Download assets: + +```bash +wget -c https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg +wget -c https://huggingface.co/lmz/candle-yolo-v8/resolve/main/yolov8s.safetensors +``` + +Run hot reload server: + +```bash +trunk serve --release --public-url / --port 8080 +``` + +### Vanilla JS and WebWorkers + +To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library: + +```bash +sh build-lib.sh +``` + +This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module: + +```js +import init, { Model } from "./build/m.js"; +``` + +The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so there is no need to download anything. +Finally, you can preview the example by running a local HTTP server. For example: + +```bash +python -m http.server +``` + +Then open `http://localhost:8000/lib-example.html` in your browser. 
diff --git a/candle-wasm-examples/yolo/index.html b/candle-wasm-examples/yolo/index.html index c64051ee..285c02ab 100644 --- a/candle-wasm-examples/yolo/index.html +++ b/candle-wasm-examples/yolo/index.html @@ -4,7 +4,7 @@ Welcome to Candle! - + diff --git a/candle-wasm-examples/yolo/lib-example.html b/candle-wasm-examples/yolo/lib-example.html new file mode 100644 index 00000000..23c94d38 --- /dev/null +++ b/candle-wasm-examples/yolo/lib-example.html @@ -0,0 +1,414 @@ + + + + Candle YOLOv8 Rust/WASM + + + + + + + + + + + + + + + +
+
+

Candle YOLOv8

+

Rust/WASM Demo

+

+ Run an object detection model on an image directly in the browser using Rust/WASM. + This demo uses the + + Candle YOLOv8 + + models to detect objects in images and a WASM runtime built with + Candle + +

+
+ +
+ + +
+ +
+
+
+ + + +
+ +
+ +
+ + +
+
+ +
+
+
+
+

Examples:

+ + + +
+
+
+
+ + + 0.25 + + + + + 0.45 +
+
+
+ +
+
+ + diff --git a/candle-wasm-examples/yolo/src/app.rs b/candle-wasm-examples/yolo/src/app.rs index bd999b6c..0df61f0f 100644 --- a/candle-wasm-examples/yolo/src/app.rs +++ b/candle-wasm-examples/yolo/src/app.rs @@ -1,5 +1,5 @@ use crate::console_log; -use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput}; +use crate::worker::{ModelData, RunData, Worker, WorkerInput, WorkerOutput}; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::JsFuture; use yew::{html, Component, Context, Html}; @@ -50,9 +50,13 @@ pub struct App { } async fn model_data_load() -> Result { - let weights = fetch_url("yolo.safetensors").await?; + let weights = fetch_url("yolov8s.safetensors").await?; + let model_size = "s".to_string(); console_log!("loaded weights {}", weights.len()); - Ok(ModelData { weights }) + Ok(ModelData { + weights, + model_size, + }) } fn performance_now() -> Option { @@ -162,7 +166,11 @@ impl Component for App { let status = format!("{err:?}"); Msg::UpdateStatus(status) } - Ok(image_data) => Msg::WorkerInMsg(WorkerInput::Run(image_data)), + Ok(image_data) => Msg::WorkerInMsg(WorkerInput::RunData(RunData { + image_data, + conf_threshold: 0.5, + iou_threshold: 0.5, + })), } }); } diff --git a/candle-wasm-examples/yolo/src/bin/m.rs b/candle-wasm-examples/yolo/src/bin/m.rs index e0c8fc5d..fac32d83 100644 --- a/candle-wasm-examples/yolo/src/bin/m.rs +++ b/candle-wasm-examples/yolo/src/bin/m.rs @@ -1,3 +1,5 @@ +use candle_wasm_example_yolo::coco_classes; +use candle_wasm_example_yolo::model::Bbox; use candle_wasm_example_yolo::worker::Model as M; use wasm_bindgen::prelude::*; @@ -9,15 +11,36 @@ pub struct Model { #[wasm_bindgen] impl Model { #[wasm_bindgen(constructor)] - pub fn new(data: Vec) -> Result { - let inner = M::load_(&data)?; + pub fn new(data: Vec, model_size: &str) -> Result { + let inner = M::load_(&data, model_size)?; Ok(Self { inner }) } #[wasm_bindgen] - pub fn run(&self, image: Vec) -> Result { - let boxes = self.inner.run(image)?; - let json = 
serde_json::to_string(&boxes)?; + pub fn run( + &self, + image: Vec, + conf_threshold: f32, + iou_threshold: f32, + ) -> Result { + let bboxes = self.inner.run(image, conf_threshold, iou_threshold)?; + let mut detections: Vec<(String, Bbox)> = vec![]; + + for (class_index, bboxes_for_class) in bboxes.iter().enumerate() { + for b in bboxes_for_class.iter() { + detections.push(( + coco_classes::NAMES[class_index].to_string(), + Bbox { + xmin: b.xmin, + ymin: b.ymin, + xmax: b.xmax, + ymax: b.ymax, + confidence: b.confidence, + }, + )); + } + } + let json = serde_json::to_string(&detections)?; Ok(json) } } diff --git a/candle-wasm-examples/yolo/src/lib.rs b/candle-wasm-examples/yolo/src/lib.rs index c691987b..524f4a19 100644 --- a/candle-wasm-examples/yolo/src/lib.rs +++ b/candle-wasm-examples/yolo/src/lib.rs @@ -1,6 +1,6 @@ mod app; -mod coco_classes; -mod model; +pub mod coco_classes; +pub mod model; pub mod worker; pub use app::App; pub use worker::Worker; diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs index 7e40fcfc..923e4d87 100644 --- a/candle-wasm-examples/yolo/src/model.rs +++ b/candle-wasm-examples/yolo/src/model.rs @@ -623,16 +623,25 @@ fn iou(b1: &Bbox, b2: &Bbox) -> f32 { i_area / (b1_area + b2_area - i_area) } -pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result>> { +pub fn report( + pred: &Tensor, + img: DynamicImage, + w: usize, + h: usize, + conf_threshold: f32, + iou_threshold: f32, +) -> Result>> { let (pred_size, npreds) = pred.dims2()?; let nclasses = pred_size - 4; + let conf_threshold = conf_threshold.clamp(0.0, 1.0); + let iou_threshold = iou_threshold.clamp(0.0, 1.0); // The bounding boxes grouped by (maximum) class index. let mut bboxes: Vec> = (0..nclasses).map(|_| vec![]).collect(); // Extract the bounding boxes for which confidence is above the threshold. 
for index in 0..npreds { let pred = Vec::::try_from(pred.i((.., index))?)?; let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap(); - if confidence > CONFIDENCE_THRESHOLD { + if confidence > conf_threshold { let mut class_index = 0; for i in 0..nclasses { if pred[4 + i] > pred[4 + class_index] { @@ -659,7 +668,7 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result NMS_THRESHOLD { + if iou > iou_threshold { drop = true; break; } diff --git a/candle-wasm-examples/yolo/src/worker.rs b/candle-wasm-examples/yolo/src/worker.rs index f6232b3c..0f4cd6f2 100644 --- a/candle-wasm-examples/yolo/src/worker.rs +++ b/candle-wasm-examples/yolo/src/worker.rs @@ -25,6 +25,14 @@ macro_rules! console_log { #[derive(Serialize, Deserialize)] pub struct ModelData { pub weights: Vec, + pub model_size: String, +} + +#[derive(Serialize, Deserialize)] +pub struct RunData { + pub image_data: Vec, + pub conf_threshold: f32, + pub iou_threshold: f32, } pub struct Model { @@ -32,7 +40,12 @@ pub struct Model { } impl Model { - pub fn run(&self, image_data: Vec) -> Result>> { + pub fn run( + &self, + image_data: Vec, + conf_threshold: f32, + iou_threshold: f32, + ) -> Result>> { console_log!("image data: {}", image_data.len()); let image_data = std::io::Cursor::new(image_data); let original_image = image::io::Reader::new(image_data) @@ -68,20 +81,37 @@ impl Model { let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. 
/ 255.))?; let predictions = self.model.forward(&image_t)?.squeeze(0)?; console_log!("generated predictions {predictions:?}"); - let bboxes = report(&predictions, original_image, width, height)?; + let bboxes = report( + &predictions, + original_image, + width, + height, + conf_threshold, + iou_threshold, + )?; Ok(bboxes) } - pub fn load_(weights: &[u8]) -> Result { + pub fn load_(weights: &[u8], model_size: &str) -> Result { + let multiples = match model_size { + "n" => Multiples::n(), + "s" => Multiples::s(), + "m" => Multiples::m(), + "l" => Multiples::l(), + "x" => Multiples::x(), + _ => Err(candle::Error::Msg( + "invalid model size: must be n, s, m, l or x".to_string(), + ))?, + }; let dev = &Device::Cpu; let weights = safetensors::tensor::SafeTensors::deserialize(weights)?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev); - let model = YoloV8::load(vb, Multiples::s(), 80)?; + let model = YoloV8::load(vb, multiples, 80)?; Ok(Self { model }) } pub fn load(md: ModelData) -> Result { - Self::load_(&md.weights) + Self::load_(&md.weights, &md.model_size.to_string()) } } @@ -93,7 +123,7 @@ pub struct Worker { #[derive(Serialize, Deserialize)] pub enum WorkerInput { ModelData(ModelData), - Run(Vec), + RunData(RunData), } #[derive(Serialize, Deserialize)] @@ -125,10 +155,12 @@ impl yew_agent::Worker for Worker { } Err(err) => Err(format!("model creation error {err:?}")), }, - WorkerInput::Run(image_data) => match &mut self.model { + WorkerInput::RunData(rd) => match &mut self.model { None => Err("model has not been set yet".to_string()), Some(model) => { - let result = model.run(image_data).map_err(|e| e.to_string()); + let result = model + .run(rd.image_data, rd.conf_threshold, rd.iou_threshold) + .map_err(|e| e.to_string()); Ok(WorkerOutput::ProcessingDone(result)) } }, diff --git a/candle-wasm-examples/yolo/yoloWorker.js b/candle-wasm-examples/yolo/yoloWorker.js new file mode 100644 index 00000000..0e45ed6f --- /dev/null +++ 
b/candle-wasm-examples/yolo/yoloWorker.js @@ -0,0 +1,49 @@ +//load the candle yolo wasm module +import init, { Model } from "./build/m.js"; + +class Yolo { + static instance = {}; + // Retrieve the YOLO model. When called for the first time, + // this will load the model and save it for future use. + static async getInstance(modelID, modelURL, modelSize) { + // load individual modelID only once + if (!this.instance[modelID]) { + await init(); + + self.postMessage({ status: `loading model ${modelID}:${modelSize}` }); + const modelRes = await fetch(modelURL); + const yoloArrayBuffer = await modelRes.arrayBuffer(); + const weightsArrayU8 = new Uint8Array(yoloArrayBuffer); + this.instance[modelID] = new Model(weightsArrayU8, modelSize); + } else { + self.postMessage({ status: "model already loaded" }); + } + return this.instance[modelID]; + } +} + +self.addEventListener("message", async (event) => { + const { imageURL, modelID, modelURL, modelSize, confidence, iou_threshold } = + event.data; + try { + self.postMessage({ status: "detecting" }); + + const yolo = await Yolo.getInstance(modelID, modelURL, modelSize); + + self.postMessage({ status: "loading image" }); + const imgRes = await fetch(imageURL); + const imgData = await imgRes.arrayBuffer(); + const imageArrayU8 = new Uint8Array(imgData); + + self.postMessage({ status: `running inference ${modelID}:${modelSize}` }); + const bboxes = yolo.run(imageArrayU8, confidence, iou_threshold); + + // Send the output back to the main thread as JSON + self.postMessage({ + status: "complete", + output: JSON.parse(bboxes), + }); + } catch (e) { + self.postMessage({ error: e }); + } +});