[segment-anything] add multi point logic for demo site (#1002)

* [segment-anything] add multi point logic for demo site

* [segment-anything] remove libs and update functions
This commit is contained in:
lichin-lin
2023-10-01 18:25:22 +01:00
committed by GitHub
parent 096dee7073
commit 41143db1af
3 changed files with 29 additions and 14 deletions

View File

@ -33,6 +33,7 @@
url: "sam_vit_b_01ec64.safetensors",
},
};
let pointArr = []
const samWorker = new Worker("./samWorker.js", { type: "module" });
async function segmentPoints(
@ -145,6 +146,7 @@
//add event listener to clear button
clearBtn.addEventListener("click", () => {
cleanImageCanvas();
pointArr = []
});
//add click event to canvas
canvas.addEventListener("click", async (event) => {
@ -155,7 +157,8 @@
const x = (event.clientX - targetBox.left) / targetBox.width;
const y = (event.clientY - targetBox.top) / targetBox.height;
isSegmenting = true;
const { maskURL } = await getSegmentationMask({ x, y });
pointArr = [...pointArr, [ x, y , true ]]
const { maskURL } = await getSegmentationMask(pointArr);
isSegmenting = false;
drawMask(maskURL);
});

View File

@ -141,7 +141,7 @@ self.addEventListener("message", async (event) => {
}
self.postMessage({ status: "segmenting", message: "Segmenting" });
const { mask, image } = sam.mask_for_point(points.x, points.y);
const { mask, image } = sam.mask_for_point({ points });
const maskDataURL = await createImageCanvas(mask, image);
// Send the segment back to the main thread as JSON
self.postMessage({

View File

@ -74,17 +74,24 @@ impl Model {
Ok(())
}
// x and y have to be between 0 and 1
pub fn mask_for_point(&self, x: f64, y: f64) -> Result<JsValue, JsError> {
if !(0. ..=1.).contains(&x) {
Err(JsError::new(&format!(
"x has to be between 0 and 1, got {x}"
)))?
}
if !(0. ..=1.).contains(&y) {
Err(JsError::new(&format!(
"y has to be between 0 and 1, got {y}"
)))?
pub fn mask_for_point(&self, input: JsValue) -> Result<JsValue, JsError> {
let input: PointsInput =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
let transformed_points = input.points;
for &(x, y, _bool) in &transformed_points {
if !(0.0..=1.0).contains(&x) {
return Err(JsError::new(&format!(
"x has to be between 0 and 1, got {}",
x
)));
}
if !(0.0..=1.0).contains(&y) {
return Err(JsError::new(&format!(
"y has to be between 0 and 1, got {}",
y
)));
}
}
let embeddings = match &self.embeddings {
None => Err(JsError::new("image embeddings have not been set"))?,
@ -94,7 +101,7 @@ impl Model {
&embeddings.data,
embeddings.height as usize,
embeddings.width as usize,
&[(x, y, true)],
&transformed_points,
false,
)?;
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
@ -134,6 +141,11 @@ struct MaskImage {
image: Image,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct PointsInput {
points: Vec<(f64, f64, bool)>,
}
fn main() {
console_error_panic_hook::set_once();
}