From 7687a0f4532544946538a5a42a6aa820c6a2c7e4 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 22 Aug 2023 22:20:08 +0100 Subject: [PATCH] Also fix the aspect ratio in the wasm example. (#556) * Also fix the aspect ratio in the wasm example. * Add the yolo lib. * Update the build script. --- .gitignore | 1 + candle-wasm-examples/yolo/build-lib.sh | 2 + candle-wasm-examples/yolo/src/bin/m.rs | 25 +++++++++++ candle-wasm-examples/yolo/src/lib.rs | 2 +- candle-wasm-examples/yolo/src/model.rs | 4 +- candle-wasm-examples/yolo/src/worker.rs | 60 +++++++++++++++---------- 6 files changed, 68 insertions(+), 26 deletions(-) create mode 100644 candle-wasm-examples/yolo/build-lib.sh create mode 100644 candle-wasm-examples/yolo/src/bin/m.rs diff --git a/.gitignore b/.gitignore index 7269ec8a..b313c3ec 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ flamegraph.svg *.swp trace-*.json +candle-wasm-examples/*/build candle-wasm-examples/*/*.bin candle-wasm-examples/*/*.jpeg candle-wasm-examples/*/*.wav diff --git a/candle-wasm-examples/yolo/build-lib.sh b/candle-wasm-examples/yolo/build-lib.sh new file mode 100644 index 00000000..b0ebb182 --- /dev/null +++ b/candle-wasm-examples/yolo/build-lib.sh @@ -0,0 +1,2 @@ +cargo build --target wasm32-unknown-unknown --release +wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web diff --git a/candle-wasm-examples/yolo/src/bin/m.rs b/candle-wasm-examples/yolo/src/bin/m.rs new file mode 100644 index 00000000..e0c8fc5d --- /dev/null +++ b/candle-wasm-examples/yolo/src/bin/m.rs @@ -0,0 +1,25 @@ +use candle_wasm_example_yolo::worker::Model as M; +use wasm_bindgen::prelude::*; + +#[wasm_bindgen] +pub struct Model { + inner: M, +} + +#[wasm_bindgen] +impl Model { + #[wasm_bindgen(constructor)] + pub fn new(data: Vec) -> Result { + let inner = M::load_(&data)?; + Ok(Self { inner }) + } + + #[wasm_bindgen] + pub fn run(&self, image: Vec) -> Result { + let boxes = self.inner.run(image)?; + let json = serde_json::to_string(&boxes)?; + Ok(json) + } +} + +fn main() {} diff --git a/candle-wasm-examples/yolo/src/lib.rs b/candle-wasm-examples/yolo/src/lib.rs index 76af1d63..c691987b 100644 --- a/candle-wasm-examples/yolo/src/lib.rs +++ b/candle-wasm-examples/yolo/src/lib.rs @@ -1,6 +1,6 @@ mod app; mod coco_classes; mod model; -mod worker; +pub mod worker; pub use app::App; pub use worker::Worker; diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs index 184045f0..50fd100c 100644 --- a/candle-wasm-examples/yolo/src/model.rs +++ b/candle-wasm-examples/yolo/src/model.rs @@ -5,8 +5,8 @@ use candle_nn::{ }; use image::DynamicImage; -const CONFIDENCE_THRESHOLD: f32 = 0.5; -const NMS_THRESHOLD: f32 = 0.4; +const CONFIDENCE_THRESHOLD: f32 = 0.25; +const NMS_THRESHOLD: f32 = 0.45; // Model architecture from https://github.com/ultralytics/ultralytics/issues/189 // https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py diff --git a/candle-wasm-examples/yolo/src/worker.rs b/candle-wasm-examples/yolo/src/worker.rs index cc029cf8..f6232b3c 100644 --- a/candle-wasm-examples/yolo/src/worker.rs +++ b/candle-wasm-examples/yolo/src/worker.rs @@ -27,46 +27,62 @@ pub struct ModelData { pub weights: Vec, } -struct Model { +pub struct Model { model: YoloV8, } impl Model { - fn run( - &self, - _link: &WorkerLink, - _id: HandlerId, - image_data: Vec, - ) -> Result>> { + pub fn run(&self, image_data: Vec) -> Result>> { console_log!("image data: {}", image_data.len()); let image_data = std::io::Cursor::new(image_data); let original_image = image::io::Reader::new(image_data) .with_guessed_format()? .decode() .map_err(candle::Error::wrap)?; - let image = { - let data = original_image - .resize_exact(640, 640, image::imageops::FilterType::Triangle) - .to_rgb8() - .into_raw(); - Tensor::from_vec(data, (640, 640, 3), &Device::Cpu)?.permute((2, 0, 1))? + let (width, height) = { + let w = original_image.width() as usize; + let h = original_image.height() as usize; + if w < h { + let w = w * 640 / h; + // Sizes have to be divisible by 32. + (w / 32 * 32, 640) + } else { + let h = h * 640 / w; + (640, h / 32 * 32) + } }; - let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?; - let predictions = self.model.forward(&image)?.squeeze(0)?; + let image_t = { + let img = original_image.resize_exact( + width as u32, + height as u32, + image::imageops::FilterType::CatmullRom, + ); + let data = img.to_rgb8().into_raw(); + Tensor::from_vec( + data, + (img.height() as usize, img.width() as usize, 3), + &Device::Cpu, + )? + .permute((2, 0, 1))? + }; + let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?; + let predictions = self.model.forward(&image_t)?.squeeze(0)?; console_log!("generated predictions {predictions:?}"); - let bboxes = report(&predictions, original_image, 640, 640)?; + let bboxes = report(&predictions, original_image, width, height)?; Ok(bboxes) } -} -impl Model { - fn load(md: ModelData) -> Result { + pub fn load_(weights: &[u8]) -> Result { let dev = &Device::Cpu; - let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?; + let weights = safetensors::tensor::SafeTensors::deserialize(weights)?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev); let model = YoloV8::load(vb, Multiples::s(), 80)?; Ok(Self { model }) } + + pub fn load(md: ModelData) -> Result { + Self::load_(&md.weights) + } } pub struct Worker { @@ -112,9 +128,7 @@ impl yew_agent::Worker for Worker { WorkerInput::Run(image_data) => match &mut self.model { None => Err("model has not been set yet".to_string()), Some(model) => { - let result = model - .run(&self.link, id, image_data) - .map_err(|e| e.to_string()); + let result = model.run(image_data).map_err(|e| e.to_string()); Ok(WorkerOutput::ProcessingDone(result)) } },