[SAM] Add undo button and background point mode (#1020)

* [SAM] Add undo button and background point mode

* [SAM] remove pts on near clicks

* [SAM] check shiftKey toggle point mode

* [SAM] clear points when clearing image
This commit is contained in:
Radamés Ajna
2023-10-02 15:33:46 -07:00
committed by GitHub
parent 7670fe7d1f
commit 65825e7240

View File

@ -33,7 +33,6 @@
url: "sam_vit_b_01ec64.safetensors", url: "sam_vit_b_01ec64.safetensors",
}, },
}; };
let pointArr = []
const samWorker = new Worker("./samWorker.js", { type: "module" }); const samWorker = new Worker("./samWorker.js", { type: "module" });
async function segmentPoints( async function segmentPoints(
@ -75,6 +74,8 @@
} }
const clearBtn = document.querySelector("#clear-btn"); const clearBtn = document.querySelector("#clear-btn");
const maskBtn = document.querySelector("#mask-btn");
const undoBtn = document.querySelector("#undo-btn");
const canvas = document.querySelector("#canvas"); const canvas = document.querySelector("#canvas");
const mask = document.querySelector("#mask"); const mask = document.querySelector("#mask");
const ctxCanvas = canvas.getContext("2d"); const ctxCanvas = canvas.getContext("2d");
@ -87,13 +88,14 @@
const statusOutput = document.querySelector("#output-status"); const statusOutput = document.querySelector("#output-status");
//add event listener to file input //add event listener to file input
fileUpload.addEventListener("change", (e) => { fileUpload.addEventListener("input", (e) => {
const target = e.target; const target = e.target;
if (target.files.length > 0) { if (target.files.length > 0) {
const href = URL.createObjectURL(target.files[0]); const href = URL.createObjectURL(target.files[0]);
cleanImageCanvas(); clearImageCanvas();
drawImageCanvas(href); drawImageCanvas(href);
setImageEmbeddings(href); setImageEmbeddings(href);
togglePointMode(false);
} }
}); });
// add event listener to drop-area // add event listener to drop-area
@ -116,13 +118,15 @@
if (files.length > 0) { if (files.length > 0) {
const href = URL.createObjectURL(files[0]); const href = URL.createObjectURL(files[0]);
cleanImageCanvas(); clearImageCanvas();
drawImageCanvas(href); drawImageCanvas(href);
setImageEmbeddings(href); setImageEmbeddings(href);
togglePointMode(false);
} else if (url) { } else if (url) {
cleanImageCanvas(); clearImageCanvas();
drawImageCanvas(url); drawImageCanvas(url);
setImageEmbeddings(url); setImageEmbeddings(url);
togglePointMode(false);
} }
}); });
@ -130,6 +134,8 @@
let isSegmenting = false; let isSegmenting = false;
let isEmbedding = false; let isEmbedding = false;
let currentImageURL = ""; let currentImageURL = "";
let pointArr = [];
let bgPointMode = false;
//add event listener to image examples //add event listener to image examples
imagesExamples.addEventListener("click", (e) => { imagesExamples.addEventListener("click", (e) => {
if (isEmbedding || isSegmenting) { if (isEmbedding || isSegmenting) {
@ -138,31 +144,91 @@
const target = e.target; const target = e.target;
if (target.nodeName === "IMG") { if (target.nodeName === "IMG") {
const href = target.src; const href = target.src;
cleanImageCanvas(); clearImageCanvas();
drawImageCanvas(href); drawImageCanvas(href);
setImageEmbeddings(href); setImageEmbeddings(href);
} }
}); });
//add event listener to mask button
maskBtn.addEventListener("click", () => {
togglePointMode();
});
//add event listener to clear button //add event listener to clear button
clearBtn.addEventListener("click", () => { clearBtn.addEventListener("click", () => {
cleanImageCanvas(); clearImageCanvas();
pointArr = [] togglePointMode(false);
pointArr = [];
});
//add event listener to undo button
undoBtn.addEventListener("click", () => {
undoPoint();
}); });
//add click event to canvas //add click event to canvas
canvas.addEventListener("click", async (event) => { canvas.addEventListener("click", async (event) => {
if (!hasImage || isEmbedding || isSegmenting) { if (!hasImage || isEmbedding || isSegmenting) {
return; return;
} }
const backgroundMode = event.shiftKey ? bgPointMode^event.shiftKey : bgPointMode;
const targetBox = event.target.getBoundingClientRect(); const targetBox = event.target.getBoundingClientRect();
const x = (event.clientX - targetBox.left) / targetBox.width; const x = (event.clientX - targetBox.left) / targetBox.width;
const y = (event.clientY - targetBox.top) / targetBox.height; const y = (event.clientY - targetBox.top) / targetBox.height;
const ptsToRemove = [];
for (const [idx, pts] of pointArr.entries()) {
const d = Math.sqrt((pts[0] - x) ** 2 + (pts[1] - y) ** 2);
if (d < 6 / targetBox.width) {
ptsToRemove.push(idx);
}
}
if (ptsToRemove.length > 0) {
pointArr = pointArr.filter((_, idx) => !ptsToRemove.includes(idx));
} else {
pointArr = [...pointArr, [x, y, !backgroundMode]];
}
undoBtn.disabled = false;
if (pointArr.length == 0) {
ctxMask.clearRect(0, 0, canvas.width, canvas.height);
undoBtn.disabled = true;
return;
}
isSegmenting = true; isSegmenting = true;
pointArr = [...pointArr, [ x, y , true ]]
const { maskURL } = await getSegmentationMask(pointArr); const { maskURL } = await getSegmentationMask(pointArr);
isSegmenting = false; isSegmenting = false;
drawMask(maskURL); drawMask(maskURL, pointArr);
}); });
/**
 * Removes the most recently added point and refreshes the mask overlay.
 *
 * Does nothing when no image is loaded, while an embedding or segmentation
 * is already in progress, or when there are no points to remove. When the
 * last remaining point is removed, the mask canvas is cleared and the undo
 * button is disabled. Otherwise a new mask is requested for the remaining
 * points and drawn together with them.
 *
 * @returns {Promise<void>} Settles once the overlay has been updated.
 *   Rejects if `getSegmentationMask` rejects or `drawMask` throws.
 */
async function undoPoint() {
  if (!hasImage || isEmbedding || isSegmenting) {
    return;
  }
  if (pointArr.length === 0) {
    return;
  }
  pointArr.pop();
  if (pointArr.length === 0) {
    ctxMask.clearRect(0, 0, canvas.width, canvas.height);
    undoBtn.disabled = true;
    return;
  }
  isSegmenting = true;
  let maskURL;
  try {
    ({ maskURL } = await getSegmentationMask(pointArr));
  } finally {
    // Release the busy flag even when the request fails; otherwise every
    // later click and undo would be ignored by the guard at the top.
    isSegmenting = false;
  }
  drawMask(maskURL, pointArr);
}
/**
 * Switches between mask-point and background-point mode and updates the
 * mask button to match.
 *
 * @param {boolean} [mode] - Value to store in `bgPointMode`. When omitted
 *   (undefined), the current value is inverted instead.
 */
function togglePointMode(mode) {
  if (mode === undefined) {
    bgPointMode = !bgPointMode;
  } else {
    bgPointMode = mode;
  }

  const label = maskBtn.querySelector("span");
  label.innerText = bgPointMode ? "Background Point" : "Mask Point";

  // Adds or removes the `hidden` attribute on one of the button's icon paths.
  const setHidden = (selector, shouldHide) => {
    const icon = maskBtn.querySelector(selector);
    if (shouldHide) {
      icon.setAttribute("hidden", "");
    } else {
      icon.removeAttribute("hidden");
    }
  };

  // The mask circle is hidden in background mode; the unmask circle is
  // hidden otherwise.
  setHidden("#mask-circle", bgPointMode);
  setHidden("#unmask-circle", !bgPointMode);
}
async function getSegmentationMask(points) { async function getSegmentationMask(points) {
const modelID = modelSelection.value; const modelID = modelSelection.value;
const modelURL = MODEL_BASEURL + MODELS[modelID].url; const modelURL = MODEL_BASEURL + MODELS[modelID].url;
@ -193,18 +259,19 @@
currentImageURL = imageURL; currentImageURL = imageURL;
} }
function cleanImageCanvas() { function clearImageCanvas() {
ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
ctxMask.clearRect(0, 0, canvas.width, canvas.height); ctxMask.clearRect(0, 0, canvas.width, canvas.height);
hasImage = false; hasImage = false;
isEmbedding = false; isEmbedding = false;
isSegmenting = false; isSegmenting = false;
currentImageURL = ""; currentImageURL = "";
clearBtn.classList.add("invisible"); pointArr = [];
clearBtn.disabled = true;
canvas.parentElement.style.height = "auto"; canvas.parentElement.style.height = "auto";
dropButtons.classList.remove("invisible"); dropButtons.classList.remove("invisible");
} }
function drawMask(maskURL) { function drawMask(maskURL, points) {
if (!maskURL) { if (!maskURL) {
throw new Error("No mask URL provided"); throw new Error("No mask URL provided");
} }
@ -215,12 +282,31 @@
img.onload = () => { img.onload = () => {
mask.width = canvas.width; mask.width = canvas.width;
mask.height = canvas.height; mask.height = canvas.height;
ctxMask.save();
ctxMask.drawImage(canvas, 0, 0); ctxMask.drawImage(canvas, 0, 0);
ctxMask.globalCompositeOperation = "source-atop"; ctxMask.globalCompositeOperation = "source-atop";
ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)"; ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)";
ctxMask.fillRect(0, 0, canvas.width, canvas.height); ctxMask.fillRect(0, 0, canvas.width, canvas.height);
ctxMask.globalCompositeOperation = "destination-in"; ctxMask.globalCompositeOperation = "destination-in";
ctxMask.drawImage(img, 0, 0); ctxMask.drawImage(img, 0, 0);
ctxMask.globalCompositeOperation = "source-over";
for (const pt of points) {
if (pt[2]) {
ctxMask.fillStyle = "rgba(0, 255, 255, 1)";
} else {
ctxMask.fillStyle = "rgba(255, 255, 0, 1)";
}
ctxMask.beginPath();
ctxMask.arc(
pt[0] * canvas.width,
pt[1] * canvas.height,
3,
0,
2 * Math.PI
);
ctxMask.fill();
}
ctxMask.restore();
}; };
img.src = maskURL; img.src = maskURL;
} }
@ -241,7 +327,7 @@
ctxCanvas.drawImage(img, 0, 0); ctxCanvas.drawImage(img, 0, 0);
canvas.parentElement.style.height = canvas.offsetHeight + "px"; canvas.parentElement.style.height = canvas.offsetHeight + "px";
hasImage = true; hasImage = true;
clearBtn.classList.remove("invisible"); clearBtn.disabled = false;
dropButtons.classList.add("invisible"); dropButtons.classList.add("invisible");
}; };
img.src = imgURL; img.src = imgURL;
@ -290,8 +376,7 @@
<label for="model" class="font-medium">Models Options: </label> <label for="model" class="font-medium">Models Options: </label>
<select <select
id="model" id="model"
class="border-2 border-gray-500 rounded-md font-light" class="border-2 border-gray-500 rounded-md font-light">
>
<option value="sam_mobile_tiny" selected> <option value="sam_mobile_tiny" selected>
Mobile SAM Tiny (40.6 MB) Mobile SAM Tiny (40.6 MB)
</option> </option>
@ -306,55 +391,82 @@
subsequent clicks on points will be significantly faster. subsequent clicks on points will be significantly faster.
</p> </p>
</div> </div>
<div class="relative max-w-lg"> <div class="relative max-w-2xl">
<div class="flex justify-between items-center"> <div class="flex justify-between items-center">
<div class="px-2 rounded-md inline text-xs"> <div class="px-2 rounded-md inline text-xs">
<span id="output-status" class="m-auto font-light"></span> <span id="output-status" class="m-auto font-light"></span>
</div> </div>
<button <div class="flex gap-2">
id="clear-btn" <button
class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center invisible" id="mask-btn"
> title="Toggle Mask Point and Background Point"
<svg class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center">
class="" <span>Mask Point</span>
xmlns="http://www.w3.org/2000/svg" <svg
viewBox="0 0 13 12" xmlns="http://www.w3.org/2000/svg"
height="1em" height="1em"
> viewBox="0 0 512 512">
<path <path
d="M1.6.7 12 11.1M12 .7 1.6 11.1" id="mask-circle"
stroke="#2E3036" d="M256 512a256 256 0 1 0 0-512 256 256 0 1 0 0 512z" />
stroke-width="2" <path
/> id="unmask-circle"
</svg> hidden
Clear image d="M464 256a208 208 0 1 0-416 0 208 208 0 1 0 416 0zM0 256a256 256 0 1 1 512 0 256 256 0 1 1-512 0z" />
</button> </svg>
</button>
<button
id="undo-btn"
disabled
title="Undo Last Point"
class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center">
<svg
xmlns="http://www.w3.org/2000/svg"
height="1em"
viewBox="0 0 512 512">
<path
d="M48.5 224H40a24 24 0 0 1-24-24V72a24 24 0 0 1 41-17l41.6 41.6a224 224 0 1 1-1 317.8 32 32 0 0 1 45.3-45.3 160 160 0 1 0 1-227.3L185 183a24 24 0 0 1-17 41H48.5z" />
</svg>
</button>
<button
id="clear-btn"
disabled
title="Clear Image"
class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center">
<svg
class=""
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 13 12"
height="1em">
<path
d="M1.6.7 12 11.1M12 .7 1.6 11.1"
stroke="#2E3036"
stroke-width="2" />
</svg>
</button>
</div>
</div> </div>
<div <div
id="drop-area" id="drop-area"
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden" class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden">
>
<div <div
id="drop-buttons" id="drop-buttons"
class="flex flex-col items-center justify-center space-y-1 text-center relative z-10" class="flex flex-col items-center justify-center space-y-1 text-center relative z-10">
>
<svg <svg
width="25" width="25"
height="25" height="25"
viewBox="0 0 25 25" viewBox="0 0 25 25"
fill="none" fill="none"
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg">
>
<path <path
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z" d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
fill="#000" fill="#000" />
/>
</svg> </svg>
<div class="flex text-sm text-gray-600"> <div class="flex text-sm text-gray-600">
<label <label
for="file-upload" for="file-upload"
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700" class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700">
>
<span>Drag and drop your image here</span> <span>Drag and drop your image here</span>
<span class="block text-xs">or</span> <span class="block text-xs">or</span>
<span class="block text-xs">Click to upload</span> <span class="block text-xs">Click to upload</span>
@ -364,45 +476,37 @@
id="file-upload" id="file-upload"
name="file-upload" name="file-upload"
type="file" type="file"
class="sr-only" class="sr-only" />
/>
</div> </div>
<canvas id="canvas" class="absolute w-full"></canvas> <canvas id="canvas" class="absolute w-full"></canvas>
<canvas <canvas
id="mask" id="mask"
class="pointer-events-none absolute w-full" class="pointer-events-none absolute w-full"></canvas>
></canvas>
</div> </div>
<div class="text-right py-2"> <div class="text-right py-2">
<button <button
id="share-btn" id="share-btn"
class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible" class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible">
>
<img <img
src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg" src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg" />
/>
</button> </button>
</div> </div>
</div> </div>
<div> <div>
<div <div
class="flex gap-3 items-center overflow-x-scroll" class="flex gap-3 items-center overflow-x-scroll"
id="image-select" id="image-select">
>
<h3 class="font-medium">Examples:</h3> <h3 class="font-medium">Examples:</h3>
<img <img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
class="cursor-pointer w-24 h-24 object-cover" class="cursor-pointer w-24 h-24 object-cover" />
/>
<img <img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
class="cursor-pointer w-24 h-24 object-cover" class="cursor-pointer w-24 h-24 object-cover" />
/>
<img <img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg"
class="cursor-pointer w-24 h-24 object-cover" class="cursor-pointer w-24 h-24 object-cover" />
/>
</div> </div>
</div> </div>
</main> </main>