diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html
index adcd02ab..f6b5931f 100644
--- a/candle-wasm-examples/segment-anything/lib-example.html
+++ b/candle-wasm-examples/segment-anything/lib-example.html
@@ -73,9 +73,12 @@
statusOutput.innerText = event.data.message;
}
+ let copyMaskURL = null;
+ let copyImageURL = null;
const clearBtn = document.querySelector("#clear-btn");
const maskBtn = document.querySelector("#mask-btn");
const undoBtn = document.querySelector("#undo-btn");
+ const downloadBtn = document.querySelector("#download-btn");
const canvas = document.querySelector("#canvas");
const mask = document.querySelector("#mask");
const ctxCanvas = canvas.getContext("2d");
@@ -93,6 +96,7 @@
if (target.files.length > 0) {
const href = URL.createObjectURL(target.files[0]);
clearImageCanvas();
+ copyImageURL = href;
drawImageCanvas(href);
setImageEmbeddings(href);
togglePointMode(false);
@@ -119,11 +123,13 @@
if (files.length > 0) {
const href = URL.createObjectURL(files[0]);
clearImageCanvas();
+ copyImageURL = href;
drawImageCanvas(href);
setImageEmbeddings(href);
togglePointMode(false);
} else if (url) {
clearImageCanvas();
+ copyImageURL = url;
drawImageCanvas(url);
setImageEmbeddings(url);
togglePointMode(false);
@@ -145,6 +151,7 @@
if (target.nodeName === "IMG") {
const href = target.src;
clearImageCanvas();
+ copyImageURL = href;
drawImageCanvas(href);
setImageEmbeddings(href);
}
@@ -163,6 +170,46 @@
undoBtn.addEventListener("click", () => {
undoPoint();
});
+ // add event to download btn
+ downloadBtn.addEventListener("click", async () => {
+ // Function to load image blobs as Image elements asynchronously
+ const loadImageAsync = (imageURL) => {
+ return new Promise((resolve) => {
+ const img = new Image();
+ img.onload = () => {
+ resolve(img);
+ };
+ img.crossOrigin = "anonymous";
+ img.src = imageURL;
+ });
+ };
+ const originalImage = await loadImageAsync(copyImageURL);
+ const maskImage = await loadImageAsync(copyMaskURL);
+
+ // create main a board to draw
+ const canvas = document.createElement("canvas");
+ const ctx = canvas.getContext("2d");
+ canvas.width = originalImage.width;
+ canvas.height = originalImage.height;
+
+ // Perform the mask operation
+ ctx.drawImage(maskImage, 0, 0);
+ ctx.globalCompositeOperation = "source-in";
+ ctx.drawImage(originalImage, 0, 0);
+
+ // to blob
+ const blobPromise = new Promise((resolve) => {
+ canvas.toBlob(resolve);
+ });
+ const blob = await blobPromise;
+ const resultURL = URL.createObjectURL(blob);
+
+ // download
+ const link = document.createElement("a");
+ link.href = resultURL;
+ link.download = "cutout.png";
+ link.click();
+ });
//add click event to canvas
canvas.addEventListener("click", async (event) => {
if (!hasImage || isEmbedding || isSegmenting) {
@@ -185,14 +232,17 @@
pointArr = [...pointArr, [x, y, !backgroundMode]];
}
undoBtn.disabled = false;
+ downloadBtn.disabled = false;
if (pointArr.length == 0) {
ctxMask.clearRect(0, 0, canvas.width, canvas.height);
undoBtn.disabled = true;
+ downloadBtn.disabled = true;
return;
}
isSegmenting = true;
const { maskURL } = await getSegmentationMask(pointArr);
isSegmenting = false;
+ copyMaskURL = maskURL;
drawMask(maskURL, pointArr);
});
@@ -212,6 +262,7 @@
isSegmenting = true;
const { maskURL } = await getSegmentationMask(pointArr);
isSegmenting = false;
+ copyMaskURL = maskURL;
drawMask(maskURL, pointArr);
}
function togglePointMode(mode) {
@@ -490,6 +541,15 @@
+
+