From 65825e724013304e4b4664a9edfce1b356cd0e40 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Radam=C3=A9s=20Ajna?=
Date: Mon, 2 Oct 2023 15:33:46 -0700
Subject: [PATCH] [SAM] Add undo button and background point mode (#1020)
* [SAM] Add undo button and background point mode
* [SAM] remove pts on near clicks
* [SAM] check shiftKey toggle point mode
* [SAM] clear points when clearing image
---
.../segment-anything/lib-example.html | 226 +++++++++++++-----
1 file changed, 165 insertions(+), 61 deletions(-)
diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html
index 6684f1a1..adcd02ab 100644
--- a/candle-wasm-examples/segment-anything/lib-example.html
+++ b/candle-wasm-examples/segment-anything/lib-example.html
@@ -33,7 +33,6 @@
url: "sam_vit_b_01ec64.safetensors",
},
};
- let pointArr = []
const samWorker = new Worker("./samWorker.js", { type: "module" });
async function segmentPoints(
@@ -75,6 +74,8 @@
}
const clearBtn = document.querySelector("#clear-btn");
+ const maskBtn = document.querySelector("#mask-btn");
+ const undoBtn = document.querySelector("#undo-btn");
const canvas = document.querySelector("#canvas");
const mask = document.querySelector("#mask");
const ctxCanvas = canvas.getContext("2d");
@@ -87,13 +88,14 @@
const statusOutput = document.querySelector("#output-status");
//add event listener to file input
- fileUpload.addEventListener("change", (e) => {
+ fileUpload.addEventListener("input", (e) => {
const target = e.target;
if (target.files.length > 0) {
const href = URL.createObjectURL(target.files[0]);
- cleanImageCanvas();
+ clearImageCanvas();
drawImageCanvas(href);
setImageEmbeddings(href);
+ togglePointMode(false);
}
});
// add event listener to drop-area
@@ -116,13 +118,15 @@
if (files.length > 0) {
const href = URL.createObjectURL(files[0]);
- cleanImageCanvas();
+ clearImageCanvas();
drawImageCanvas(href);
setImageEmbeddings(href);
+ togglePointMode(false);
} else if (url) {
- cleanImageCanvas();
+ clearImageCanvas();
drawImageCanvas(url);
setImageEmbeddings(url);
+ togglePointMode(false);
}
});
@@ -130,6 +134,8 @@
let isSegmenting = false;
let isEmbedding = false;
let currentImageURL = "";
+ let pointArr = [];
+ let bgPointMode = false;
//add event listener to image examples
imagesExamples.addEventListener("click", (e) => {
if (isEmbedding || isSegmenting) {
@@ -138,31 +144,91 @@
const target = e.target;
if (target.nodeName === "IMG") {
const href = target.src;
- cleanImageCanvas();
+ clearImageCanvas();
drawImageCanvas(href);
setImageEmbeddings(href);
}
});
+ //add event listener to mask button
+ maskBtn.addEventListener("click", () => {
+ togglePointMode();
+ });
//add event listener to clear button
clearBtn.addEventListener("click", () => {
- cleanImageCanvas();
- pointArr = []
+ clearImageCanvas();
+ togglePointMode(false);
+ pointArr = [];
+ });
+ //add event listener to undo button
+ undoBtn.addEventListener("click", () => {
+ undoPoint();
});
//add click event to canvas
+ // A click adds a prompt point, or removes every existing point lying within
+ // the near-click radius of it, and then re-runs segmentation on the result.
+ // Holding Shift inverts the current point mode for that single click.
canvas.addEventListener("click", async (event) => {
if (!hasImage || isEmbedding || isSegmenting) {
return;
}
+ const backgroundMode = event.shiftKey ? !bgPointMode : bgPointMode;
const targetBox = event.target.getBoundingClientRect();
const x = (event.clientX - targetBox.left) / targetBox.width;
const y = (event.clientY - targetBox.top) / targetBox.height;
+ // Points are stored as fractions of the element's width/height. Scale the
+ // offsets back to CSS pixels before measuring so the removal radius is the
+ // same horizontally and vertically, whatever the image aspect ratio.
+ const nearClickRadiusPx = 6;
+ const remaining = pointArr.filter(([px, py]) => {
+   const dx = (px - x) * targetBox.width;
+   const dy = (py - y) * targetBox.height;
+   return Math.hypot(dx, dy) >= nearClickRadiusPx;
+ });
+ if (remaining.length < pointArr.length) {
+   // The click landed on at least one existing point: remove instead of add.
+   pointArr = remaining;
+ } else {
+   // Third entry is true for a mask point, false for a background point.
+   pointArr = [...pointArr, [x, y, !backgroundMode]];
+ }
+ if (pointArr.length === 0) {
+   ctxMask.clearRect(0, 0, canvas.width, canvas.height);
+   undoBtn.disabled = true;
+   return;
+ }
+ undoBtn.disabled = false;
isSegmenting = true;
- pointArr = [...pointArr, [ x, y , true ]]
- const { maskURL } = await getSegmentationMask(pointArr);
- isSegmenting = false;
- drawMask(maskURL);
+ try {
+   const { maskURL } = await getSegmentationMask(pointArr);
+   drawMask(maskURL, pointArr);
+ } finally {
+   // Always release the busy flag so a failed request cannot lock the UI.
+   isSegmenting = false;
+ }
});
+ /**
+  * Removes the most recently added point and refreshes the mask.
+  *
+  * Does nothing while no image is loaded, while an embedding or segmentation
+  * is in progress, or when there are no points. When the last point is removed
+  * the mask canvas is cleared and the undo button disabled; otherwise the mask
+  * is recomputed from the remaining points.
+  */
+ async function undoPoint() {
+   if (!hasImage || isEmbedding || isSegmenting || pointArr.length === 0) {
+     return;
+   }
+   // Replace the array instead of mutating it, as the canvas click handler
+   // does, so an array already passed to drawMask is left untouched.
+   pointArr = pointArr.slice(0, -1);
+   if (pointArr.length === 0) {
+     ctxMask.clearRect(0, 0, canvas.width, canvas.height);
+     undoBtn.disabled = true;
+     return;
+   }
+   isSegmenting = true;
+   try {
+     const { maskURL } = await getSegmentationMask(pointArr);
+     drawMask(maskURL, pointArr);
+   } finally {
+     // Always release the busy flag so a failed request cannot lock the UI.
+     isSegmenting = false;
+   }
+ }
+ /**
+  * Switches between mask-point and background-point mode and updates the
+  * mode button (label text and indicator circles) to match.
+  *
+  * @param {boolean} [mode] - Background mode to set; when omitted, the
+  *   current mode is flipped.
+  */
+ function togglePointMode(mode) {
+   if (mode === undefined) {
+     bgPointMode = !bgPointMode;
+   } else {
+     bgPointMode = mode;
+   }
+
+   const label = bgPointMode ? "Background Point" : "Mask Point";
+   maskBtn.querySelector("span").innerText = label;
+
+   // Adds or removes the `hidden` attribute on one indicator element.
+   const setHidden = (el, hide) => {
+     if (hide) {
+       el.setAttribute("hidden", "");
+     } else {
+       el.removeAttribute("hidden");
+     }
+   };
+   // Exactly one of the two circles is visible at a time.
+   setHidden(maskBtn.querySelector("#mask-circle"), bgPointMode);
+   setHidden(maskBtn.querySelector("#unmask-circle"), !bgPointMode);
+ }
+
async function getSegmentationMask(points) {
const modelID = modelSelection.value;
const modelURL = MODEL_BASEURL + MODELS[modelID].url;
@@ -193,18 +259,19 @@
currentImageURL = imageURL;
}
- function cleanImageCanvas() {
+ /**
+  * Resets the viewer to its empty state: clears the image and mask drawing
+  * contexts, resets the image/embedding/segmentation flags and current URL,
+  * drops all points, disables the clear and undo buttons, restores the
+  * canvas container height to "auto", and shows the drop buttons again.
+  */
+ function clearImageCanvas() {
ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
ctxMask.clearRect(0, 0, canvas.width, canvas.height);
hasImage = false;
isEmbedding = false;
isSegmenting = false;
currentImageURL = "";
- clearBtn.classList.add("invisible");
+ pointArr = [];
+ clearBtn.disabled = true;
+ // No points remain, so there is nothing left to undo.
+ undoBtn.disabled = true;
canvas.parentElement.style.height = "auto";
dropButtons.classList.remove("invisible");
}
- function drawMask(maskURL) {
+ function drawMask(maskURL, points) {
if (!maskURL) {
throw new Error("No mask URL provided");
}
@@ -215,12 +282,31 @@
img.onload = () => {
mask.width = canvas.width;
mask.height = canvas.height;
+ ctxMask.save();
ctxMask.drawImage(canvas, 0, 0);
ctxMask.globalCompositeOperation = "source-atop";
ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)";
ctxMask.fillRect(0, 0, canvas.width, canvas.height);
ctxMask.globalCompositeOperation = "destination-in";
ctxMask.drawImage(img, 0, 0);
+ ctxMask.globalCompositeOperation = "source-over";
+ for (const pt of points) {
+ if (pt[2]) {
+ ctxMask.fillStyle = "rgba(0, 255, 255, 1)";
+ } else {
+ ctxMask.fillStyle = "rgba(255, 255, 0, 1)";
+ }
+ ctxMask.beginPath();
+ ctxMask.arc(
+ pt[0] * canvas.width,
+ pt[1] * canvas.height,
+ 3,
+ 0,
+ 2 * Math.PI
+ );
+ ctxMask.fill();
+ }
+ ctxMask.restore();
};
img.src = maskURL;
}
@@ -241,7 +327,7 @@
ctxCanvas.drawImage(img, 0, 0);
canvas.parentElement.style.height = canvas.offsetHeight + "px";
hasImage = true;
- clearBtn.classList.remove("invisible");
+ clearBtn.disabled = false;
dropButtons.classList.add("invisible");
};
img.src = imgURL;
@@ -290,8 +376,7 @@
-
+
-
+
+
+
+
+
+
+ class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden">
+ class="flex flex-col items-center justify-center space-y-1 text-center relative z-10">
+ class="pointer-events-none absolute w-full">