mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add SAM UI Demo (#854)
* fix tensor flattening * send image data back * sam ui worker example * SAM example * resize container * no need for this
This commit is contained in:
26
candle-wasm-examples/segment-anything/README.md
Normal file
26
candle-wasm-examples/segment-anything/README.md
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
## Running Segment Anything Example
|
||||||
|
|
||||||
|
Here, we provide an example of how to run the Segment Anything Model (SAM) using a Candle-compiled WASM binary and runtime.
|
||||||
|
|
||||||
|
### Vanilla JS and WebWorkers
|
||||||
|
|
||||||
|
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sh build-lib.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
|
||||||
|
|
||||||
|
```js
|
||||||
|
import init, { Model } from "./build/m.js";
|
||||||
|
```
|
||||||
|
|
||||||
|
The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
|
||||||
|
Finally, you can preview the example by running a local HTTP server. For example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m http.server
|
||||||
|
```
|
||||||
|
|
||||||
|
Then open `http://localhost:8000/lib-example.html` in your browser.
|
407
candle-wasm-examples/segment-anything/lib-example.html
Normal file
407
candle-wasm-examples/segment-anything/lib-example.html
Normal file
@ -0,0 +1,407 @@
|
|||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
|
||||||
|
<title>Candle Segment Anything Model (SAM) Rust/WASM</title>
|
||||||
|
</head>
|
||||||
|
<body></body>
|
||||||
|
</html>
|
||||||
|
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<style>
|
||||||
|
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
|
||||||
|
html,
|
||||||
|
body {
|
||||||
|
font-family: "Source Sans 3", sans-serif;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
<script src="https://cdn.tailwindcss.com"></script>
|
||||||
|
<script type="module">
|
||||||
|
// Root URL that the model weight file names below are appended to
const MODEL_BASEURL = "https://huggingface.co/lmz/candle-sam/resolve/main/";

// Weight file for each model ID; the keys are the values of the #model <select> options
const MODELS = {
  sam_mobile_tiny: { url: "mobile_sam-tiny-vitt.safetensors" },
  sam_base: { url: "sam_vit_b_01ec64.safetensors" },
};

// Module worker that receives the segmentation requests posted by segmentPoints()
const samWorker = new Worker("./samWorker.js", { type: "module" });
|
||||||
|
|
||||||
|
/**
 * Posts one request to the SAM worker and waits for its terminal message.
 *
 * Every worker message carrying a `status` is forwarded to updateStatus().
 * The promise resolves with no value on "complete-embedding", resolves with
 * the worker's `output` on "complete", and rejects with an Error when a
 * message contains an `error` field.
 */
async function segmentPoints(
  modelURL, // URL to the weights file
  modelID, // model ID
  imageURL, // URL to the image file
  points // {x, y} point used to prompt the model; omit to only compute embeddings
) {
  return new Promise((resolve, reject) => {
    // Stop listening once this request has reached a terminal state
    const detach = () =>
      samWorker.removeEventListener("message", onWorkerMessage);

    function onWorkerMessage(event) {
      const payload = event.data;
      console.log(payload);
      if ("status" in payload) {
        updateStatus(payload);
      }
      if ("error" in payload) {
        detach();
        reject(new Error(payload.error));
      }
      switch (payload.status) {
        case "complete-embedding":
          detach();
          resolve();
          break;
        case "complete":
          detach();
          resolve(payload.output);
          break;
      }
    }

    samWorker.addEventListener("message", onWorkerMessage);
    samWorker.postMessage({ modelURL, modelID, imageURL, points });
  });
}
|
||||||
|
/**
 * Shows the worker's latest progress message in the status element.
 *
 * @param {{ message: string }} statusMessage - payload posted by the worker
 */
function updateStatus(statusMessage) {
  // Read the argument itself; the previous code read the implicit global
  // `event`, which is only defined while an event is being dispatched.
  statusOutput.innerText = statusMessage.message;
}
|
||||||
|
|
||||||
|
// Resolve every element the script interacts with in a single pass
const [
  clearBtn,
  canvas,
  mask,
  fileUpload,
  dropArea,
  dropButtons,
  imagesExamples,
  modelSelection,
  statusOutput,
] = [
  "#clear-btn",
  "#canvas",
  "#mask",
  "#file-upload",
  "#drop-area",
  "#drop-buttons",
  "#image-select",
  "#model",
  "#output-status",
].map((selector) => document.querySelector(selector));

// 2D drawing contexts for the image canvas and the mask overlay canvas
const ctxCanvas = canvas.getContext("2d");
const ctxMask = mask.getContext("2d");
|
||||||
|
|
||||||
|
// When a file is picked, show it on the canvas and start computing its embeddings
fileUpload.addEventListener("change", ({ target }) => {
  if (target.files.length === 0) {
    return;
  }
  const href = URL.createObjectURL(target.files[0]);
  cleanImageCanvas();
  drawImageCanvas(href);
  setImageEmbeddings(href);
});
|
||||||
|
// Drag-and-drop support: highlight the drop area while something is dragged
// over it, then load either the dropped file or the dropped URL.
dropArea.addEventListener("dragenter", (dragEvent) => {
  dragEvent.preventDefault();
  dropArea.classList.add("border-blue-700");
});
dropArea.addEventListener("dragleave", (dragEvent) => {
  dragEvent.preventDefault();
  dropArea.classList.remove("border-blue-700");
});
// Default handling must be cancelled on dragover for the drop event to fire
dropArea.addEventListener("dragover", (dragEvent) => {
  dragEvent.preventDefault();
});
dropArea.addEventListener("drop", (dropEvent) => {
  dropEvent.preventDefault();
  dropArea.classList.remove("border-blue-700");

  const { dataTransfer } = dropEvent;
  const url = dataTransfer.getData("text/uri-list");
  const files = dataTransfer.files;

  // A dropped file takes precedence over a dropped link
  const href = files.length > 0 ? URL.createObjectURL(files[0]) : url;
  if (!href) {
    return;
  }
  cleanImageCanvas();
  drawImageCanvas(href);
  setImageEmbeddings(href);
});
|
||||||
|
|
||||||
|
// UI state shared by the handlers and helpers in this script
let hasImage = false; // an image has finished loading into the canvas
let isSegmenting = false; // a point-prompt request is in flight
let isEmbedding = false; // an embeddings request is in flight
let currentImageURL = ""; // URL of the image whose embeddings were last computed

// Clicking an example thumbnail loads it, unless a request is still running
imagesExamples.addEventListener("click", ({ target }) => {
  const busy = isEmbedding || isSegmenting;
  if (busy || target.nodeName !== "IMG") {
    return;
  }
  const href = target.src;
  cleanImageCanvas();
  drawImageCanvas(href);
  setImageEmbeddings(href);
});

// The clear button resets the canvases and the UI state
clearBtn.addEventListener("click", () => cleanImageCanvas());
|
||||||
|
// Clicking the image prompts the model with that point and draws the
// resulting mask. Clicks are ignored until an image is loaded and while
// another request is still running.
canvas.addEventListener("click", async (event) => {
  if (!hasImage || isEmbedding || isSegmenting) {
    return;
  }
  // Convert the click position to coordinates relative to the canvas, in [0, 1]
  const targetBox = event.target.getBoundingClientRect();
  const x = (event.clientX - targetBox.left) / targetBox.width;
  const y = (event.clientY - targetBox.top) / targetBox.height;

  isSegmenting = true;
  try {
    const { maskURL } = await getSegmentationMask({ x, y });
    drawMask(maskURL);
  } catch (err) {
    // Report the failure instead of leaving an unhandled rejection
    console.error(err);
  } finally {
    // Always release the flag, otherwise a single failed request would
    // make the canvas ignore every later click
    isSegmenting = false;
  }
});
|
||||||
|
|
||||||
|
/**
 * Requests a mask for `points` on the current image, using the model that is
 * currently selected in the #model dropdown.
 *
 * @param {{ x: number, y: number }} points - prompt point
 * @returns {Promise<{ maskURL: string }>} URL of the mask image produced by the worker
 */
async function getSegmentationMask(points) {
  const modelID = modelSelection.value;
  const output = await segmentPoints(
    MODEL_BASEURL + MODELS[modelID].url,
    modelID,
    currentImageURL,
    points
  );
  return { maskURL: output.maskURL };
}
|
||||||
|
/**
 * Asks the worker to compute embeddings for `imageURL` with the selected
 * model. While the request runs, the canvas shows a wait cursor and the clear
 * button is disabled. `currentImageURL` is updated only on success.
 *
 * Does nothing if an embeddings request is already in flight. Rejects with
 * the worker's error if the request fails.
 *
 * @param {string} imageURL - URL of the image to embed
 */
async function setImageEmbeddings(imageURL) {
  if (isEmbedding) {
    return;
  }
  // Resolve the model before touching the UI so a bad model ID cannot leave
  // the page stuck in the busy state
  const modelID = modelSelection.value;
  const modelURL = MODEL_BASEURL + MODELS[modelID].url;

  canvas.classList.remove("cursor-pointer");
  canvas.classList.add("cursor-wait");
  clearBtn.disabled = true;
  isEmbedding = true;
  try {
    await segmentPoints(modelURL, modelID, imageURL);
    currentImageURL = imageURL;
  } finally {
    // Restore the UI even when the worker reports an error; otherwise the
    // page would stay locked until it is reloaded
    canvas.classList.remove("cursor-wait");
    canvas.classList.add("cursor-pointer");
    clearBtn.disabled = false;
    isEmbedding = false;
  }
}
|
||||||
|
|
||||||
|
/**
 * Resets the page to its "no image" state: wipes both canvases, clears the
 * shared state flags, hides the clear button and shows the upload prompt.
 */
function cleanImageCanvas() {
  const { width, height } = canvas;
  for (const ctx of [ctxCanvas, ctxMask]) {
    ctx.clearRect(0, 0, width, height);
  }

  hasImage = false;
  isEmbedding = false;
  isSegmenting = false;
  currentImageURL = "";

  clearBtn.classList.add("invisible");
  // Let the container shrink back to its natural height
  canvas.parentElement.style.height = "auto";
  dropButtons.classList.remove("invisible");
}
|
||||||
|
/**
 * Loads the mask image at `maskURL` and paints the masked region of the
 * current image, tinted red, onto the overlay canvas.
 *
 * @param {string} maskURL - URL of the mask image
 * @throws {Error} if `maskURL` is empty
 */
function drawMask(maskURL) {
  if (!maskURL) {
    throw new Error("No mask URL provided");
  }

  const maskImage = new Image();
  maskImage.crossOrigin = "anonymous";
  maskImage.addEventListener("load", () => {
    const { width, height } = canvas;
    // Match the overlay to the image canvas
    mask.width = width;
    mask.height = height;
    // Copy the image, tint the copied pixels red ...
    ctxMask.drawImage(canvas, 0, 0);
    ctxMask.globalCompositeOperation = "source-atop";
    ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)";
    ctxMask.fillRect(0, 0, width, height);
    // ... then keep only the pixels that the mask image covers
    ctxMask.globalCompositeOperation = "destination-in";
    ctxMask.drawImage(maskImage, 0, 0);
  });
  maskImage.src = maskURL;
}
|
||||||
|
/**
 * Loads the image at `imgURL` and, once it has loaded, resizes the canvas to
 * the image, draws it, sizes the container to the rendered canvas, marks the
 * page as having an image, shows the clear button and hides the upload prompt.
 *
 * @param {string} imgURL - URL of the image to show
 * @throws {Error} if `imgURL` is empty
 */
function drawImageCanvas(imgURL) {
  if (!imgURL) {
    throw new Error("No image URL provided");
  }

  // Wipe the previous image and its mask overlay. The image canvas used to be
  // cleared twice here while the overlay was left untouched.
  ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
  ctxMask.clearRect(0, 0, mask.width, mask.height);

  const img = new Image();
  img.crossOrigin = "anonymous";

  img.onload = () => {
    canvas.width = img.width;
    canvas.height = img.height;
    ctxCanvas.drawImage(img, 0, 0);
    // The canvas is absolutely positioned, so its container needs an explicit height
    canvas.parentElement.style.height = canvas.offsetHeight + "px";
    hasImage = true;
    clearBtn.classList.remove("invisible");
    dropButtons.classList.add("invisible");
  };
  img.src = imgURL;
}
|
||||||
|
|
||||||
|
// Keep the container as tall as the rendered canvas whenever the canvas is resized
const observer = new ResizeObserver((entries) => {
  entries
    .filter((entry) => entry.target === canvas)
    .forEach(() => {
      canvas.parentElement.style.height = `${canvas.offsetHeight}px`;
    });
});
observer.observe(canvas);
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
<body class="container max-w-4xl mx-auto p-4">
|
||||||
|
<main class="grid grid-cols-1 gap-8 relative">
|
||||||
|
<span class="absolute text-5xl -ml-[1em]">🕯️</span>
|
||||||
|
<div>
|
||||||
|
<h1 class="text-5xl font-bold">Candle Segment Anything</h1>
|
||||||
|
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
|
||||||
|
<p class="max-w-lg">
|
||||||
|
Zero-shot image segmentation with
|
||||||
|
<a
|
||||||
|
href="https://segment-anything.com"
|
||||||
|
class="underline hover:text-blue-500 hover:no-underline"
|
||||||
|
target="_blank"
|
||||||
|
>Segment Anything Model (SAM)</a
|
||||||
|
>
|
||||||
|
and
|
||||||
|
<a
|
||||||
|
href="https://github.com/ChaoningZhang/MobileSAM"
|
||||||
|
class="underline hover:text-blue-500 hover:no-underline"
|
||||||
|
target="_blank"
|
||||||
|
>MobileSAM </a
|
||||||
|
>. It runs in the browser with a WASM runtime built with
|
||||||
|
<a
|
||||||
|
href="https://github.com/huggingface/candle/"
|
||||||
|
target="_blank"
|
||||||
|
class="underline hover:text-blue-500 hover:no-underline"
|
||||||
|
>Candle
|
||||||
|
</a>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label for="model" class="font-medium">Models Options: </label>
|
||||||
|
<select
|
||||||
|
id="model"
|
||||||
|
class="border-2 border-gray-500 rounded-md font-light"
|
||||||
|
>
|
||||||
|
<option value="sam_mobile_tiny" selected>
|
||||||
|
Mobile SAM Tiny (40.6 MB)
|
||||||
|
</option>
|
||||||
|
<option value="sam_base">SAM Base (375 MB)</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<p class="text-xs italic max-w-lg">
|
||||||
|
<b>Note:</b>
|
||||||
|
The model's first run may take a few seconds as it loads and caches
|
||||||
|
the model in the browser, and then creates the image embeddings. Any
|
||||||
|
subsequent clicks on points will be significantly faster.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="relative max-w-lg">
|
||||||
|
<div class="flex justify-between items-center">
|
||||||
|
<div class="px-2 rounded-md inline text-xs">
|
||||||
|
<span id="output-status" class="m-auto font-light"></span>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
id="clear-btn"
|
||||||
|
class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center invisible"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
class=""
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
viewBox="0 0 13 12"
|
||||||
|
height="1em"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
d="M1.6.7 12 11.1M12 .7 1.6 11.1"
|
||||||
|
stroke="#2E3036"
|
||||||
|
stroke-width="2"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
Clear image
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
id="drop-area"
|
||||||
|
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
id="drop-buttons"
|
||||||
|
class="flex flex-col items-center justify-center space-y-1 text-center relative z-10"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
width="25"
|
||||||
|
height="25"
|
||||||
|
viewBox="0 0 25 25"
|
||||||
|
fill="none"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
|
||||||
|
fill="#000"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
<div class="flex text-sm text-gray-600">
|
||||||
|
<label
|
||||||
|
for="file-upload"
|
||||||
|
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
|
||||||
|
>
|
||||||
|
<span>Drag and drop your image here</span>
|
||||||
|
<span class="block text-xs">or</span>
|
||||||
|
<span class="block text-xs">Click to upload</span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<input
|
||||||
|
id="file-upload"
|
||||||
|
name="file-upload"
|
||||||
|
type="file"
|
||||||
|
class="sr-only"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<canvas id="canvas" class="absolute w-full"></canvas>
|
||||||
|
<canvas
|
||||||
|
id="mask"
|
||||||
|
class="pointer-events-none absolute w-full"
|
||||||
|
></canvas>
|
||||||
|
</div>
|
||||||
|
<div class="text-right py-2">
|
||||||
|
<button
|
||||||
|
id="share-btn"
|
||||||
|
class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible"
|
||||||
|
>
|
||||||
|
<img
|
||||||
|
src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<div
|
||||||
|
class="flex gap-3 items-center overflow-x-scroll"
|
||||||
|
id="image-select"
|
||||||
|
>
|
||||||
|
<h3 class="font-medium">Examples:</h3>
|
||||||
|
|
||||||
|
<img
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
|
||||||
|
class="cursor-pointer w-24 h-24 object-cover"
|
||||||
|
/>
|
||||||
|
<img
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
|
||||||
|
class="cursor-pointer w-24 h-24 object-cover"
|
||||||
|
/>
|
||||||
|
<img
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg"
|
||||||
|
class="cursor-pointer w-24 h-24 object-cover"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</main>
|
||||||
|
</body>
|
||||||
|
</html>
|
156
candle-wasm-examples/segment-anything/samWorker.js
Normal file
156
candle-wasm-examples/segment-anything/samWorker.js
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
//load the candle SAM Model wasm module
|
||||||
|
import init, { Model } from "./build/m.js";
|
||||||
|
|
||||||
|
/**
 * Downloads `url` and returns its body as bytes.
 *
 * With `cacheModel` enabled (the default) the response is looked up in, and
 * stored into, the "sam-candle-cache" Cache Storage bucket so that large
 * weight files are downloaded only once.
 *
 * @param {string} url - resource to download
 * @param {boolean} [cacheModel=true] - whether to use the Cache Storage bucket
 * @returns {Promise<Uint8Array>} the response body
 * @throws {Error} if the server answers with a non-2xx status
 */
async function fetchArrayBuffer(url, cacheModel = true) {
  // Reject HTTP errors up front so that an error page is never handed to the
  // model as weights or image data, and never written to the cache
  const assertOk = (res) => {
    if (!res.ok) {
      throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
    }
    return res;
  };

  if (!cacheModel) {
    const res = assertOk(await fetch(url));
    return new Uint8Array(await res.arrayBuffer());
  }

  const cacheName = "sam-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    return new Uint8Array(await cachedResponse.arrayBuffer());
  }

  const res = assertOk(await fetch(url, { cache: "force-cache" }));
  // Caching is best effort: a failed write (e.g. storage quota) is logged
  // but must not fail the download itself
  const cacheWrite = cache.put(url, res.clone()).catch((err) => {
    console.warn(`Could not cache ${url}`, err);
  });
  const bytes = new Uint8Array(await res.arrayBuffer());
  await cacheWrite;
  return bytes;
}
|
||||||
|
/**
 * Holds the loaded WASM models and remembers, per model, which image its
 * embeddings were last computed for.
 */
class SAMModel {
  // Loaded models, keyed by model ID
  static instance = {};
  // Hash of the image whose embeddings each model currently holds, keyed by model ID
  static imageArrayHash = {};
  // Model ID passed to the most recent getInstance() call
  static currentModelID = null;

  /**
   * Returns the model for `modelID`, downloading its weights from `modelURL`
   * and constructing it on first use. Also makes it the current model.
   */
  static async getInstance(modelURL, modelID) {
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({
        status: "loading",
        message: `Loading Model ${modelID}`,
      });
      const weightsArrayU8 = await fetchArrayBuffer(modelURL);
      this.instance[modelID] = new Model(
        weightsArrayU8,
        // second constructor argument is true for the "tiny"/"mobile" model IDs
        /tiny|mobile/.test(modelID)
      );
    } else {
      self.postMessage({ status: "loading", message: "Model Already Loaded" });
    }
    this.currentModelID = modelID;
    return this.instance[modelID];
  }

  /**
   * Computes embeddings for `imageArrayU8` on the current model, unless that
   * model already holds the embeddings of the same image.
   */
  static setImageEmbeddings(imageArrayU8) {
    const modelID = this.currentModelID;
    const model = this.instance[modelID];
    const imageArrayHash = this.getSimpleHash(imageArrayU8);
    if (this.imageArrayHash[modelID] === imageArrayHash && model) {
      self.postMessage({
        status: "embedding",
        message: "Embeddings Already Set",
      });
      return;
    }
    // Forget the previous image first and record the new one only after the
    // call succeeds, so a failed attempt is never reported as "already set"
    this.imageArrayHash[modelID] = undefined;
    model.set_image_embeddings(imageArrayU8);
    this.imageArrayHash[modelID] = imageArrayHash;
    self.postMessage({ status: "embedding", message: "Embeddings Set" });
  }

  /**
   * Cheap fingerprint of an image: a 32-bit FNV-1a hash over the byte length
   * and every 100th byte, returned as a hex string. Only used to recognise a
   * repeated image; it is not collision-proof.
   */
  static getSimpleHash(imageArrayU8) {
    const FNV_OFFSET_BASIS = 0x811c9dc5;
    const FNV_PRIME = 0x01000193;
    const SAMPLE_STRIDE = 100;

    // Mixing in the length separates images whose sampled bytes coincide
    let hash = Math.imul(FNV_OFFSET_BASIS ^ imageArrayU8.length, FNV_PRIME);
    for (let i = 0; i < imageArrayU8.length; i += SAMPLE_STRIDE) {
      hash = Math.imul(hash ^ imageArrayU8[i], FNV_PRIME);
    }
    // >>> 0 turns the signed 32-bit result into an unsigned value
    return (hash >>> 0).toString(16);
  }
}
|
||||||
|
|
||||||
|
/**
 * Renders the model's mask into an image with the dimensions of the original
 * picture and returns an object URL for it.
 *
 * Each mask value becomes the alpha of a black pixel (value * 255). The
 * top-left region of the mask canvas, scaled by the picture's aspect ratio,
 * is then stretched onto a canvas of the original size.
 *
 * NOTE(review): the 3rd and 4th entries of `mask_shape` are used as canvas
 * width and height respectively — confirm the axis order against the Rust
 * side if masks can be non-square. `width` and `height` are accepted but unused.
 */
async function createImageCanvas(
  { mask_shape, mask_data }, // mask
  { original_width, original_height, width, height } // original image
) {
  const [, , shape_width, shape_height] = mask_shape;
  const maskCanvas = new OffscreenCanvas(shape_width, shape_height); // canvas for mask
  const maskCtx = maskCanvas.getContext("2d");
  const canvas = new OffscreenCanvas(original_width, original_height); // canvas for creating mask with original image size
  const ctx = canvas.getContext("2d");

  // One RGBA pixel per mask value: black, with the mask value as alpha
  const imageData = maskCtx.createImageData(maskCanvas.width, maskCanvas.height);
  const pixels = imageData.data;
  const pixelCount = pixels.length / 4;
  for (let i = 0; i < pixelCount; i++) {
    const base = i * 4;
    pixels[base] = 0;
    pixels[base + 1] = 0;
    pixels[base + 2] = 0;
    pixels[base + 3] = mask_data[i] * 255;
  }
  maskCtx.putImageData(imageData, 0, 0);

  // Fraction of the mask canvas to sample along each axis: the longer side
  // of the picture uses the full extent, the shorter side a proportional part
  const isLandscape = original_height < original_width;
  const sx = isLandscape ? 1 : original_width / original_height;
  const sy = isLandscape ? original_height / original_width : 1;

  ctx.drawImage(
    maskCanvas,
    0,
    0,
    maskCanvas.width * sx,
    maskCanvas.height * sy,
    0,
    0,
    original_width,
    original_height
  );

  return URL.createObjectURL(await canvas.convertToBlob());
}
|
||||||
|
|
||||||
|
// Handles one request from the page: load the model, load the image, compute
// its embeddings and, if a prompt point was given, compute and return a mask.
// Progress is reported through `status` messages; failures through an
// `error` message.
self.addEventListener("message", async (event) => {
  const { modelURL, modelID, imageURL, points } = event.data;
  try {
    self.postMessage({ status: "loading", message: "Starting SAM" });
    const sam = await SAMModel.getInstance(modelURL, modelID);

    self.postMessage({ status: "loading", message: "Loading Image" });
    // Images are not stored in the model cache
    const imageArrayU8 = await fetchArrayBuffer(imageURL, false);

    self.postMessage({ status: "embedding", message: "Creating Embeddings" });
    SAMModel.setImageEmbeddings(imageArrayU8);
    if (!points) {
      // no points only do the embeddings
      self.postMessage({
        status: "complete-embedding",
        message: "Embeddings Complete",
      });
      return;
    }

    self.postMessage({ status: "segmenting", message: "Segmenting" });
    const result = sam.mask_for_point(points.x, points.y);
    const { mask, image } = JSON.parse(result);
    const maskDataURL = await createImageCanvas(mask, image);
    // Send the mask back to the main thread as an object URL
    self.postMessage({
      status: "complete",
      message: "Segmentation Complete",
      output: { maskURL: maskDataURL },
    });
  } catch (e) {
    // Post a plain string. The caught value may not be structured-cloneable,
    // which would make postMessage itself throw, and the page wraps this
    // field in `new Error(...)`, so it expects a message.
    self.postMessage({ error: e instanceof Error ? e.message : String(e) });
  }
});
|
@ -98,7 +98,7 @@ impl Model {
|
|||||||
Some((x, y)),
|
Some((x, y)),
|
||||||
false,
|
false,
|
||||||
)?;
|
)?;
|
||||||
let iou = iou_predictions.to_vec1::<f32>()?[0];
|
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
|
||||||
let mask_shape = mask.dims().to_vec();
|
let mask_shape = mask.dims().to_vec();
|
||||||
let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
|
let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
|
||||||
let mask = Mask {
|
let mask = Mask {
|
||||||
@ -106,7 +106,13 @@ impl Model {
|
|||||||
mask_shape,
|
mask_shape,
|
||||||
mask_data,
|
mask_data,
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&mask)?;
|
let image = Image {
|
||||||
|
original_width: embeddings.original_width,
|
||||||
|
original_height: embeddings.original_height,
|
||||||
|
width: embeddings.width,
|
||||||
|
height: embeddings.height,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&MaskImage { mask, image })?;
|
||||||
Ok(json)
|
Ok(json)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -117,6 +123,18 @@ struct Mask {
|
|||||||
mask_shape: Vec<usize>,
|
mask_shape: Vec<usize>,
|
||||||
mask_data: Vec<u8>,
|
mask_data: Vec<u8>,
|
||||||
}
|
}
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
struct Image {
|
||||||
|
original_width: u32,
|
||||||
|
original_height: u32,
|
||||||
|
width: u32,
|
||||||
|
height: u32,
|
||||||
|
}
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
struct MaskImage {
|
||||||
|
mask: Mask,
|
||||||
|
image: Image,
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
console_error_panic_hook::set_once();
|
console_error_panic_hook::set_once();
|
||||||
|
Reference in New Issue
Block a user