mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00

* [segment-anything] add multi point logic for demo site * [segment-anything] remove libs and update functions
156 lines
4.7 KiB
JavaScript
156 lines
4.7 KiB
JavaScript
//load the candle SAM Model wasm module
|
|
import init, { Model } from "./build/m.js";
|
|
|
|
/**
 * Download `url` and return its body as bytes, optionally through Cache Storage.
 *
 * With `cacheModel` true (the default) the response is first looked up in the
 * "sam-candle-cache" cache; on a miss it is fetched and stored there so later
 * loads of the same URL are served from the cache. With `cacheModel` false the
 * resource is always fetched from the network and never cached.
 *
 * @param {string} url - Resource to download.
 * @param {boolean} [cacheModel=true] - Whether to read from / write to the cache.
 * @returns {Promise<Uint8Array>} The response body.
 * @throws {Error} If the server answers with a non-2xx status.
 */
async function fetchArrayBuffer(url, cacheModel = true) {
  // Reject HTTP errors instead of handing an error page back as model/image
  // bytes (and, on the cached path, persisting it for every later load).
  const fetchOk = async (init) => {
    const res = await fetch(url, init);
    if (!res.ok) {
      throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
    }
    return res;
  };

  if (!cacheModel) {
    const res = await fetchOk();
    return new Uint8Array(await res.arrayBuffer());
  }

  const cacheName = "sam-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    return new Uint8Array(await cachedResponse.arrayBuffer());
  }

  const res = await fetchOk({ cache: "force-cache" });
  // Caching is best effort (e.g. storage quota may be exceeded): log a failure
  // rather than leaving an unhandled rejection, and still return the bytes.
  const stored = cache.put(url, res.clone()).catch((err) => {
    console.warn(`Could not cache ${url}:`, err);
  });
  const data = new Uint8Array(await res.arrayBuffer());
  await stored;
  return data;
}
|
|
/**
 * Worker-side registry of loaded SAM models and the image each one has embedded.
 *
 * All state is static: the worker keeps at most one Model per modelID, and the
 * "current" model is whichever one getInstance returned most recently.
 */
class SAMModel {
  // Loaded Model objects, keyed by modelID.
  static instance = {};
  // Per-modelID hash (see getSimpleHash) of the image whose embeddings are set
  // on that model; used to skip recomputing embeddings for the same image.
  static imageArrayHash = {};
  // modelID most recently passed to getInstance; setImageEmbeddings acts on it.
  static currentModelID = null;

  /**
   * Return the Model for `modelID`, loading its weights from `modelURL` the
   * first time, and make it the current model.
   *
   * Posts a "loading" status message either way.
   *
   * @param {string} modelURL - URL of the model weights (fetched with caching).
   * @param {string} modelID - Key identifying the model.
   * @returns {Promise<Model>} The (possibly newly created) model instance.
   */
  static async getInstance(modelURL, modelID) {
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({
        status: "loading",
        message: `Loading Model ${modelID}`,
      });
      const weightsArrayU8 = await fetchArrayBuffer(modelURL);
      // NOTE(review): the second Model argument is true for IDs containing
      // "tiny" or "mobile" — presumably selects a smaller model variant; confirm
      // against the wasm bindings.
      this.instance[modelID] = new Model(
        weightsArrayU8,
        /tiny|mobile/.test(modelID)
      );
    } else {
      self.postMessage({ status: "loading", message: "Model Already Loaded" });
    }

    this.currentModelID = modelID;
    return this.instance[modelID];
  }

  /**
   * Compute image embeddings on the current model, unless that model already
   * holds embeddings for an image with the same hash.
   *
   * Posts an "embedding" status message describing which case occurred.
   *
   * @param {Uint8Array} imageArrayU8 - Raw bytes of the image to embed.
   * @throws Whatever `set_image_embeddings` throws; in that case no hash is
   *   recorded, so the next call retries instead of reporting a false cache hit.
   */
  static setImageEmbeddings(imageArrayU8) {
    const imageArrayHash = this.getSimpleHash(imageArrayU8);
    const modelID = this.currentModelID;
    const model = this.instance[modelID];

    if (this.imageArrayHash[modelID] === imageArrayHash && model) {
      self.postMessage({
        status: "embedding",
        message: "Embeddings Already Set",
      });
      return;
    }

    // Invalidate before the call: if it throws, the model's embedding state is
    // unknown and must not be treated as matching any image.
    this.imageArrayHash[modelID] = null;
    model.set_image_embeddings(imageArrayU8);
    this.imageArrayHash[modelID] = imageArrayHash;
    self.postMessage({ status: "embedding", message: "Embeddings Set" });
  }

  /**
   * Cheap, non-cryptographic fingerprint of a byte array, used only to detect
   * that the same image is being embedded again.
   *
   * 32-bit FNV-1a over every byte, prefixed with the byte length, as hex.
   * (The previous XOR of every 100th byte had only 256 possible values, so
   * distinct images collided and reused stale embeddings.)
   *
   * @param {Uint8Array} imageArrayU8 - Bytes to fingerprint.
   * @returns {string} Hex fingerprint of the form "<length>-<hash>".
   */
  static getSimpleHash(imageArrayU8) {
    let hash = 0x811c9dc5; // FNV-1a 32-bit offset basis
    for (let i = 0; i < imageArrayU8.length; i++) {
      hash ^= imageArrayU8[i];
      hash = Math.imul(hash, 0x01000193); // FNV 32-bit prime, mod 2^32
    }
    return `${imageArrayU8.length.toString(16)}-${(hash >>> 0).toString(16)}`;
  }
}
|
|
|
|
/**
 * Render a segmentation mask to an object URL sized like the original image.
 *
 * Every mask value becomes the alpha of a black pixel (value * 255), so masked
 * pixels come out opaque black and the rest transparent. Only a fraction of the
 * mask along the image's shorter side is kept — that side's length divided by
 * the longer side's (presumably because the image was padded to a square before
 * inference — confirm). The kept region is then stretched to
 * original_width x original_height.
 *
 * @param {{mask_shape: ArrayLike<number>, mask_data: ArrayLike<number>}} mask
 *   NOTE(review): mask_shape looks like a 4-entry tensor shape whose last two
 *   entries are used here as canvas width/height — confirm the axis order.
 * @param {{original_width: number, original_height: number}} image
 * @returns {Promise<string>} Object URL of the rendered mask blob.
 */
async function createImageCanvas(
  { mask_shape, mask_data }, // mask
  { original_width, original_height } // original image
) {
  const [, , shape_width, shape_height] = mask_shape;

  // Scratch canvas at mask resolution, then the output canvas at image resolution.
  const maskCanvas = new OffscreenCanvas(shape_width, shape_height);
  const maskCtx = maskCanvas.getContext("2d");
  const canvas = new OffscreenCanvas(original_width, original_height);
  const ctx = canvas.getContext("2d");

  const imageData = maskCtx.createImageData(maskCanvas.width, maskCanvas.height);
  const pixels = imageData.data;
  const pixelCount = pixels.length / 4; // RGBA: 4 bytes per pixel
  for (let i = 0; i < pixelCount; i++) {
    const offset = i * 4;
    pixels[offset] = 0;
    pixels[offset + 1] = 0;
    pixels[offset + 2] = 0;
    pixels[offset + 3] = mask_data[i] * 255;
  }
  maskCtx.putImageData(imageData, 0, 0);

  // Fraction of the mask to keep on each axis, from the image's aspect ratio.
  const isLandscape = original_height < original_width;
  const sx = isLandscape ? 1 : original_width / original_height;
  const sy = isLandscape ? original_height / original_width : 1;

  ctx.drawImage(
    maskCanvas,
    0,
    0,
    maskCanvas.width * sx,
    maskCanvas.height * sy,
    0,
    0,
    original_width,
    original_height
  );

  const blob = await canvas.convertToBlob();
  return URL.createObjectURL(blob);
}
|
|
|
|
// Worker entry point. Each message carries { modelURL, modelID, imageURL, points }.
// The handler loads (or reuses) the model and embeds the image unless that model
// already holds embeddings for the same image. When `points` is present, it also
// segments the image and posts back an object URL for the mask. Progress is
// reported as { status, message }; any thrown value is posted as { error }.
self.addEventListener("message", async ({ data }) => {
  const { modelURL, modelID, imageURL, points } = data;
  const report = (status, message) => self.postMessage({ status, message });

  try {
    report("loading", "Starting SAM");
    const sam = await SAMModel.getInstance(modelURL, modelID);

    report("loading", "Loading Image");
    // The image is fetched with cacheModel = false, i.e. never via Cache Storage.
    const imageArrayU8 = await fetchArrayBuffer(imageURL, false);

    report("embedding", "Creating Embeddings");
    SAMModel.setImageEmbeddings(imageArrayU8);

    // A request without points only computes the embeddings.
    if (!points) {
      report("complete-embedding", "Embeddings Complete");
      return;
    }

    report("segmenting", "Segmenting");
    const { mask, image } = sam.mask_for_point({ points });
    const maskURL = await createImageCanvas(mask, image);

    // Return the mask's object URL to the main thread.
    self.postMessage({
      status: "complete",
      message: "Segmentation Complete",
      output: { maskURL },
    });
  } catch (e) {
    self.postMessage({ error: e });
  }
});
|