From 47c25a567bd14ab3e830b5d768cd80f33ed9545b Mon Sep 17 00:00:00 2001 From: lichin-lin Date: Thu, 5 Oct 2023 22:14:47 +0100 Subject: [PATCH] feat: [SAM] able to download the result as png (#1035) * feat: able to download the result as png * feat: update function and wording --- .../segment-anything/lib-example.html | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html index adcd02ab..f6b5931f 100644 --- a/candle-wasm-examples/segment-anything/lib-example.html +++ b/candle-wasm-examples/segment-anything/lib-example.html @@ -73,9 +73,12 @@ statusOutput.innerText = event.data.message; } + let copyMaskURL = null; + let copyImageURL = null; const clearBtn = document.querySelector("#clear-btn"); const maskBtn = document.querySelector("#mask-btn"); const undoBtn = document.querySelector("#undo-btn"); + const downloadBtn = document.querySelector("#download-btn"); const canvas = document.querySelector("#canvas"); const mask = document.querySelector("#mask"); const ctxCanvas = canvas.getContext("2d"); @@ -93,6 +96,7 @@ if (target.files.length > 0) { const href = URL.createObjectURL(target.files[0]); clearImageCanvas(); + copyImageURL = href; drawImageCanvas(href); setImageEmbeddings(href); togglePointMode(false); @@ -119,11 +123,13 @@ if (files.length > 0) { const href = URL.createObjectURL(files[0]); clearImageCanvas(); + copyImageURL = href; drawImageCanvas(href); setImageEmbeddings(href); togglePointMode(false); } else if (url) { clearImageCanvas(); + copyImageURL = url; drawImageCanvas(url); setImageEmbeddings(url); togglePointMode(false); @@ -145,6 +151,7 @@ if (target.nodeName === "IMG") { const href = target.src; clearImageCanvas(); + copyImageURL = href; drawImageCanvas(href); setImageEmbeddings(href); } @@ -163,6 +170,46 @@ undoBtn.addEventListener("click", () => { undoPoint(); }); + // add event to download btn + downloadBtn.addEventListener("click", async () => { + // Function to load image blobs as Image elements asynchronously + const loadImageAsync = (imageURL) => { + return new Promise((resolve) => { + const img = new Image(); + img.onload = () => { + resolve(img); + }; + img.crossOrigin = "anonymous"; + img.src = imageURL; + }); + }; + const originalImage = await loadImageAsync(copyImageURL); + const maskImage = await loadImageAsync(copyMaskURL); + + // create main a board to draw + const canvas = document.createElement("canvas"); + const ctx = canvas.getContext("2d"); + canvas.width = originalImage.width; + canvas.height = originalImage.height; + + // Perform the mask operation + ctx.drawImage(maskImage, 0, 0); + ctx.globalCompositeOperation = "source-in"; + ctx.drawImage(originalImage, 0, 0); + + // to blob + const blobPromise = new Promise((resolve) => { + canvas.toBlob(resolve); + }); + const blob = await blobPromise; + const resultURL = URL.createObjectURL(blob); + + // download + const link = document.createElement("a"); + link.href = resultURL; + link.download = "cutout.png"; + link.click(); + }); //add click event to canvas canvas.addEventListener("click", async (event) => { if (!hasImage || isEmbedding || isSegmenting) { @@ -185,14 +232,17 @@ pointArr = [...pointArr, [x, y, !backgroundMode]]; } undoBtn.disabled = false; + downloadBtn.disabled = false; if (pointArr.length == 0) { ctxMask.clearRect(0, 0, canvas.width, canvas.height); undoBtn.disabled = true; + downloadBtn.disabled = true; return; } isSegmenting = true; const { maskURL } = await getSegmentationMask(pointArr); isSegmenting = false; + copyMaskURL = maskURL; drawMask(maskURL, pointArr); }); @@ -212,6 +262,7 @@ isSegmenting = true; const { maskURL } = await getSegmentationMask(pointArr); isSegmenting = false; + copyMaskURL = maskURL; drawMask(maskURL, pointArr); } function togglePointMode(mode) { @@ -490,6 +541,15 @@ + +