From 65825e724013304e4b4664a9edfce1b356cd0e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radam=C3=A9s=20Ajna?= Date: Mon, 2 Oct 2023 15:33:46 -0700 Subject: [PATCH] [SAM] Add undo button and background point mode (#1020) * [SAM] Add undo button and background point mode * [SAM] remove pts on near clicks * [SAM] check shiftKey toggle point mode * [SAM] clear points when clearing image --- .../segment-anything/lib-example.html | 226 +++++++++++++----- 1 file changed, 165 insertions(+), 61 deletions(-) diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html index 6684f1a1..adcd02ab 100644 --- a/candle-wasm-examples/segment-anything/lib-example.html +++ b/candle-wasm-examples/segment-anything/lib-example.html @@ -33,7 +33,6 @@ url: "sam_vit_b_01ec64.safetensors", }, }; - let pointArr = [] const samWorker = new Worker("./samWorker.js", { type: "module" }); async function segmentPoints( @@ -75,6 +74,8 @@ } const clearBtn = document.querySelector("#clear-btn"); + const maskBtn = document.querySelector("#mask-btn"); + const undoBtn = document.querySelector("#undo-btn"); const canvas = document.querySelector("#canvas"); const mask = document.querySelector("#mask"); const ctxCanvas = canvas.getContext("2d"); @@ -87,13 +88,14 @@ const statusOutput = document.querySelector("#output-status"); //add event listener to file input - fileUpload.addEventListener("change", (e) => { + fileUpload.addEventListener("input", (e) => { const target = e.target; if (target.files.length > 0) { const href = URL.createObjectURL(target.files[0]); - cleanImageCanvas(); + clearImageCanvas(); drawImageCanvas(href); setImageEmbeddings(href); + togglePointMode(false); } }); // add event listener to drop-area @@ -116,13 +118,15 @@ if (files.length > 0) { const href = URL.createObjectURL(files[0]); - cleanImageCanvas(); + clearImageCanvas(); drawImageCanvas(href); setImageEmbeddings(href); + togglePointMode(false); } else if (url) { - cleanImageCanvas(); + clearImageCanvas(); drawImageCanvas(url); setImageEmbeddings(url); + togglePointMode(false); } }); @@ -130,6 +134,8 @@ let isSegmenting = false; let isEmbedding = false; let currentImageURL = ""; + let pointArr = []; + let bgPointMode = false; //add event listener to image examples imagesExamples.addEventListener("click", (e) => { if (isEmbedding || isSegmenting) { @@ -138,31 +144,91 @@ const target = e.target; if (target.nodeName === "IMG") { const href = target.src; - cleanImageCanvas(); + clearImageCanvas(); drawImageCanvas(href); setImageEmbeddings(href); } }); + //add event listener to mask button + maskBtn.addEventListener("click", () => { + togglePointMode(); + }); //add event listener to clear button clearBtn.addEventListener("click", () => { - cleanImageCanvas(); - pointArr = [] + clearImageCanvas(); + togglePointMode(false); + pointArr = []; + }); + //add event listener to undo button + undoBtn.addEventListener("click", () => { + undoPoint(); }); //add click event to canvas canvas.addEventListener("click", async (event) => { if (!hasImage || isEmbedding || isSegmenting) { return; } + const backgroundMode = event.shiftKey ? bgPointMode^event.shiftKey : bgPointMode; const targetBox = event.target.getBoundingClientRect(); const x = (event.clientX - targetBox.left) / targetBox.width; const y = (event.clientY - targetBox.top) / targetBox.height; + const ptsToRemove = []; + for (const [idx, pts] of pointArr.entries()) { + const d = Math.sqrt((pts[0] - x) ** 2 + (pts[1] - y) ** 2); + if (d < 6 / targetBox.width) { + ptsToRemove.push(idx); + } + } + if (ptsToRemove.length > 0) { + pointArr = pointArr.filter((_, idx) => !ptsToRemove.includes(idx)); + } else { + pointArr = [...pointArr, [x, y, !backgroundMode]]; + } + undoBtn.disabled = false; + if (pointArr.length == 0) { + ctxMask.clearRect(0, 0, canvas.width, canvas.height); + undoBtn.disabled = true; + return; + } isSegmenting = true; - pointArr = [...pointArr, [ x, y , true ]] const { maskURL } = await getSegmentationMask(pointArr); isSegmenting = false; - drawMask(maskURL); + drawMask(maskURL, pointArr); }); + async function undoPoint() { + if (!hasImage || isEmbedding || isSegmenting) { + return; + } + if (pointArr.length === 0) { + return; + } + pointArr.pop(); + if (pointArr.length === 0) { + ctxMask.clearRect(0, 0, canvas.width, canvas.height); + undoBtn.disabled = true; + return; + } + isSegmenting = true; + const { maskURL } = await getSegmentationMask(pointArr); + isSegmenting = false; + drawMask(maskURL, pointArr); + } + function togglePointMode(mode) { + bgPointMode = mode === undefined ? !bgPointMode : mode; + + maskBtn.querySelector("span").innerText = bgPointMode + ? "Background Point" + : "Mask Point"; + if (bgPointMode) { + maskBtn.querySelector("#mask-circle").setAttribute("hidden", ""); + maskBtn.querySelector("#unmask-circle").removeAttribute("hidden"); + } else { + maskBtn.querySelector("#mask-circle").removeAttribute("hidden"); + maskBtn.querySelector("#unmask-circle").setAttribute("hidden", ""); + } + } + async function getSegmentationMask(points) { const modelID = modelSelection.value; const modelURL = MODEL_BASEURL + MODELS[modelID].url; @@ -193,18 +259,19 @@ currentImageURL = imageURL; } - function cleanImageCanvas() { + function clearImageCanvas() { ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); ctxMask.clearRect(0, 0, canvas.width, canvas.height); hasImage = false; isEmbedding = false; isSegmenting = false; currentImageURL = ""; - clearBtn.classList.add("invisible"); + pointArr = []; + clearBtn.disabled = true; canvas.parentElement.style.height = "auto"; dropButtons.classList.remove("invisible"); } - function drawMask(maskURL) { + function drawMask(maskURL, points) { if (!maskURL) { throw new Error("No mask URL provided"); } @@ -215,12 +282,31 @@ img.onload = () => { mask.width = canvas.width; mask.height = canvas.height; + ctxMask.save(); ctxMask.drawImage(canvas, 0, 0); ctxMask.globalCompositeOperation = "source-atop"; ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)"; ctxMask.fillRect(0, 0, canvas.width, canvas.height); ctxMask.globalCompositeOperation = "destination-in"; ctxMask.drawImage(img, 0, 0); + ctxMask.globalCompositeOperation = "source-over"; + for (const pt of points) { + if (pt[2]) { + ctxMask.fillStyle = "rgba(0, 255, 255, 1)"; + } else { + ctxMask.fillStyle = "rgba(255, 255, 0, 1)"; + } + ctxMask.beginPath(); + ctxMask.arc( + pt[0] * canvas.width, + pt[1] * canvas.height, + 3, + 0, + 2 * Math.PI + ); + ctxMask.fill(); + } + ctxMask.restore(); }; img.src = maskURL; } @@ -241,7 +327,7 @@ ctxCanvas.drawImage(img, 0, 0); canvas.parentElement.style.height = canvas.offsetHeight + "px"; hasImage = true; - clearBtn.classList.remove("invisible"); + clearBtn.disabled = false; dropButtons.classList.add("invisible"); }; img.src = imgURL; @@ -290,8 +376,7 @@