From 39157346cb96d7610ddb6cdc686aed32d12bba3d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Radam=C3=A9s=20Ajna?=
Date: Thu, 14 Sep 2023 22:31:58 -0700
Subject: [PATCH] Add SAM UI Demo (#854)

* fix tensor flattening
* send image data back
* sam ui worker example
* SAM example
* resize container
* no need for this
---
 .../segment-anything/README.md        |  26 ++
 .../segment-anything/lib-example.html | 407 ++++++++++++++++++
 .../segment-anything/samWorker.js     | 156 +++++++
 .../segment-anything/src/bin/m.rs     |  22 +-
 4 files changed, 609 insertions(+), 2 deletions(-)
 create mode 100644 candle-wasm-examples/segment-anything/README.md
 create mode 100644 candle-wasm-examples/segment-anything/lib-example.html
 create mode 100644 candle-wasm-examples/segment-anything/samWorker.js

diff --git a/candle-wasm-examples/segment-anything/README.md b/candle-wasm-examples/segment-anything/README.md
new file mode 100644
index 00000000..04ff2033
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/README.md
@@ -0,0 +1,26 @@
+## Running Segment Anything Example
+
+Here, we provide an example of how to run Segment Anything using a Candle-compiled WASM binary and runtime.
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build`, and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so there is no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html
new file mode 100644
index 00000000..127b9152
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/lib-example.html
@@ -0,0 +1,407 @@
+<!-- page title: "Candle Segment Anything Model (SAM) Rust/WASM"; remaining markup, styles, and scripts omitted -->
+🕯️
+Candle Segment Anything
+Rust/WASM Demo
+Zero-shot image segmentation with Segment Anything Model (SAM) and MobileSAM.
+It runs in the browser with a WASM runtime built with Candle.
+Note: The model's first run may take a few seconds as it loads and caches the
+model in the browser, and then creates the image embeddings. Any subsequent
+clicks on points will be significantly faster.
+Examples:
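The page's own script is not shown above, but it drives `samWorker.js` purely over `postMessage`. Below is a minimal, illustrative sketch of that exchange based on the worker's message handler further down; the model URL, image URL, and element id are assumptions, and the point coordinates are assumed to be relative (0..1) click positions rather than anything specified by this patch.

```js
// Illustrative sketch only; not the script shipped in lib-example.html.
// samWorker.js uses ES module imports, so it must be started as a module worker.
const samWorker = new Worker("./samWorker.js", { type: "module" });

samWorker.addEventListener("message", (event) => {
  const { status, message, output, error } = event.data;
  if (error) return console.error(error);
  if (status === "complete-embedding") {
    console.log("Embeddings ready; send a point to segment.");
  } else if (status === "complete") {
    // output.maskURL is an object URL for the rendered mask image.
    document.querySelector("#mask").src = output.maskURL; // "#mask" is an assumed element id
  } else {
    console.log(status, message); // loading / embedding / segmenting progress
  }
});

// A modelID matching /tiny|mobile/ tells the worker to treat the weights as the
// tiny / MobileSAM variant.
const request = {
  modelURL: "https://example.com/mobile_sam-tiny-vitt.safetensors", // assumed weights URL
  modelID: "sam_mobile_tiny",
  imageURL: "https://example.com/photo.jpg", // assumed image URL
};

// Without `points`, the worker only creates and caches the image embeddings.
samWorker.postMessage(request);

// With `points`, it runs mask_for_point and replies with { output: { maskURL } }.
samWorker.postMessage({ ...request, points: { x: 0.5, y: 0.5 } });
```

Sending the embeddings-only message first matches the note above: the first request pays the model download and embedding cost, and later point clicks reuse the cached state.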
+ + diff --git a/candle-wasm-examples/segment-anything/samWorker.js b/candle-wasm-examples/segment-anything/samWorker.js new file mode 100644 index 00000000..b90498de --- /dev/null +++ b/candle-wasm-examples/segment-anything/samWorker.js @@ -0,0 +1,156 @@ +//load the candle SAM Model wasm module +import init, { Model } from "./build/m.js"; + +async function fetchArrayBuffer(url, cacheModel = true) { + if (!cacheModel) + return new Uint8Array(await (await fetch(url)).arrayBuffer()); + const cacheName = "sam-candle-cache"; + const cache = await caches.open(cacheName); + const cachedResponse = await cache.match(url); + if (cachedResponse) { + const data = await cachedResponse.arrayBuffer(); + return new Uint8Array(data); + } + const res = await fetch(url, { cache: "force-cache" }); + cache.put(url, res.clone()); + return new Uint8Array(await res.arrayBuffer()); +} +class SAMModel { + static instance = {}; + // keep current image embeddings state + static imageArrayHash = {}; + // Add a new property to hold the current modelID + static currentModelID = null; + + static async getInstance(modelURL, modelID) { + if (!this.instance[modelID]) { + await init(); + + self.postMessage({ + status: "loading", + message: `Loading Model ${modelID}`, + }); + const weightsArrayU8 = await fetchArrayBuffer(modelURL); + this.instance[modelID] = new Model( + weightsArrayU8, + /tiny|mobile/.test(modelID) + ); + } else { + self.postMessage({ status: "loading", message: "Model Already Loaded" }); + } + // Set the current modelID to the modelID that was passed in + this.currentModelID = modelID; + return this.instance[modelID]; + } + + // Remove the modelID parameter from setImageEmbeddings + static setImageEmbeddings(imageArrayU8) { + // check if image embeddings are already set for this image and model + const imageArrayHash = this.getSimpleHash(imageArrayU8); + if ( + this.imageArrayHash[this.currentModelID] === imageArrayHash && + this.instance[this.currentModelID] + ) { + self.postMessage({ + status: "embedding", + message: "Embeddings Already Set", + }); + return; + } + this.imageArrayHash[this.currentModelID] = imageArrayHash; + this.instance[this.currentModelID].set_image_embeddings(imageArrayU8); + self.postMessage({ status: "embedding", message: "Embeddings Set" }); + } + + static getSimpleHash(imageArrayU8) { + // get simple hash of imageArrayU8 + let imageArrayHash = 0; + for (let i = 0; i < imageArrayU8.length; i += 100) { + imageArrayHash ^= imageArrayU8[i]; + } + return imageArrayHash.toString(16); + } +} + +async function createImageCanvas( + { mask_shape, mask_data }, // mask + { original_width, original_height, width, height } // original image +) { + const [_, __, shape_width, shape_height] = mask_shape; + const maskCanvas = new OffscreenCanvas(shape_width, shape_height); // canvas for mask + const maskCtx = maskCanvas.getContext("2d"); + const canvas = new OffscreenCanvas(original_width, original_height); // canvas for creating mask with original image size + const ctx = canvas.getContext("2d"); + + const imageData = maskCtx.createImageData( + maskCanvas.width, + maskCanvas.height + ); + const data = imageData.data; + + for (let p = 0; p < data.length; p += 4) { + data[p] = 0; + data[p + 1] = 0; + data[p + 2] = 0; + data[p + 3] = mask_data[p / 4] * 255; + } + maskCtx.putImageData(imageData, 0, 0); + + let sx, sy; + if (original_height < original_width) { + sy = original_height / original_width; + sx = 1; + } else { + sy = 1; + sx = original_width / original_height; + } + ctx.drawImage( + 
maskCanvas,
+    0,
+    0,
+    maskCanvas.width * sx,
+    maskCanvas.height * sy,
+    0,
+    0,
+    original_width,
+    original_height
+  );
+
+  const blob = await canvas.convertToBlob();
+  return URL.createObjectURL(blob);
+}
+
+self.addEventListener("message", async (event) => {
+  const { modelURL, modelID, imageURL, points } = event.data;
+  try {
+    self.postMessage({ status: "loading", message: "Starting SAM" });
+    const sam = await SAMModel.getInstance(modelURL, modelID);
+
+    self.postMessage({ status: "loading", message: "Loading Image" });
+    const imageArrayU8 = await fetchArrayBuffer(imageURL, false);
+
+    self.postMessage({ status: "embedding", message: "Creating Embeddings" });
+    SAMModel.setImageEmbeddings(imageArrayU8);
+    if (!points) {
+      // no points: only compute the embeddings
+      self.postMessage({
+        status: "complete-embedding",
+        message: "Embeddings Complete",
+      });
+      return;
+    }
+
+    self.postMessage({ status: "segmenting", message: "Segmenting" });
+    const result = sam.mask_for_point(points.x, points.y);
+    const { mask, image } = JSON.parse(result);
+    const maskDataURL = await createImageCanvas(mask, image);
+    // Send the segment back to the main thread
+    self.postMessage({
+      status: "complete",
+      message: "Segmentation Complete",
+      output: { maskURL: maskDataURL },
+    });
+  } catch (e) {
+    self.postMessage({ error: e });
+  }
+});
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
index b53f5b9b..949c18a0 100644
--- a/candle-wasm-examples/segment-anything/src/bin/m.rs
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -98,7 +98,7 @@ impl Model {
             Some((x, y)),
             false,
         )?;
-        let iou = iou_predictions.to_vec1::<f32>()?[0];
+        let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
         let mask_shape = mask.dims().to_vec();
         let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
         let mask = Mask {
@@ -106,7 +106,13 @@
             mask_shape,
             mask_data,
         };
-        let json = serde_json::to_string(&mask)?;
+        let image = Image {
+            original_width: embeddings.original_width,
+            original_height: embeddings.original_height,
+            width: embeddings.width,
+            height: embeddings.height,
+        };
+        let json = serde_json::to_string(&MaskImage { mask, image })?;
         Ok(json)
     }
 }
@@ -117,6 +123,18 @@ struct Mask {
     mask_shape: Vec<usize>,
     mask_data: Vec<u8>,
 }
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Image {
+    original_width: u32,
+    original_height: u32,
+    width: u32,
+    height: u32,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct MaskImage {
+    mask: Mask,
+    image: Image,
+}
 fn main() {
     console_error_panic_hook::set_once();
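With this change, `mask_for_point` returns a single JSON string carrying both the mask and the image dimensions stored with the embeddings, which is what `samWorker.js` destructures before drawing. A minimal sketch of that shape on the JS side, assuming the string came from the worker's `Model` instance (field names are from the patch; nothing else is):

```js
// Sketch (not part of the patch): the payload parsed in samWorker.js after this change.
function describeMaskPayload(jsonString) {
  const { mask, image } = JSON.parse(jsonString);
  // createImageCanvas reads the last two entries of mask_shape as the mask canvas
  // size; mask_data holds one 0/1 value per mask pixel and becomes the alpha channel.
  const [, , shapeWidth, shapeHeight] = mask.mask_shape;
  // image carries the dimensions recorded with the embeddings and is used to
  // scale the mask back onto the original image.
  const { original_width, original_height, width, height } = image;
  return { shapeWidth, shapeHeight, original_width, original_height, width, height };
}
```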