diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html index 5060f073..6684f1a1 100644 --- a/candle-wasm-examples/segment-anything/lib-example.html +++ b/candle-wasm-examples/segment-anything/lib-example.html @@ -33,6 +33,7 @@ url: "sam_vit_b_01ec64.safetensors", }, }; + let pointArr = [] const samWorker = new Worker("./samWorker.js", { type: "module" }); async function segmentPoints( @@ -145,6 +146,7 @@ //add event listener to clear button clearBtn.addEventListener("click", () => { cleanImageCanvas(); + pointArr = [] }); //add click event to canvas canvas.addEventListener("click", async (event) => { @@ -155,7 +157,8 @@ const x = (event.clientX - targetBox.left) / targetBox.width; const y = (event.clientY - targetBox.top) / targetBox.height; isSegmenting = true; - const { maskURL } = await getSegmentationMask({ x, y }); + pointArr = [...pointArr, [ x, y , true ]] + const { maskURL } = await getSegmentationMask(pointArr); isSegmenting = false; drawMask(maskURL); }); diff --git a/candle-wasm-examples/segment-anything/samWorker.js b/candle-wasm-examples/segment-anything/samWorker.js index c1a152ef..5d0a1b5c 100644 --- a/candle-wasm-examples/segment-anything/samWorker.js +++ b/candle-wasm-examples/segment-anything/samWorker.js @@ -141,7 +141,7 @@ self.addEventListener("message", async (event) => { } self.postMessage({ status: "segmenting", message: "Segmenting" }); - const { mask, image } = sam.mask_for_point(points.x, points.y); + const { mask, image } = sam.mask_for_point({ points }); const maskDataURL = await createImageCanvas(mask, image); // Send the segment back to the main thread as JSON self.postMessage({ diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs index 12349493..2be59adc 100644 --- a/candle-wasm-examples/segment-anything/src/bin/m.rs +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -74,17 +74,24 @@ impl Model { Ok(()) } - // x and y have to be between 0 and 1 - pub fn mask_for_point(&self, x: f64, y: f64) -> Result { - if !(0. ..=1.).contains(&x) { - Err(JsError::new(&format!( - "x has to be between 0 and 1, got {x}" - )))? - } - if !(0. ..=1.).contains(&y) { - Err(JsError::new(&format!( - "y has to be between 0 and 1, got {y}" - )))? + pub fn mask_for_point(&self, input: JsValue) -> Result { + let input: PointsInput = + serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; + let transformed_points = input.points; + + for &(x, y, _bool) in &transformed_points { + if !(0.0..=1.0).contains(&x) { + return Err(JsError::new(&format!( + "x has to be between 0 and 1, got {}", + x + ))); + } + if !(0.0..=1.0).contains(&y) { + return Err(JsError::new(&format!( + "y has to be between 0 and 1, got {}", + y + ))); + } } let embeddings = match &self.embeddings { None => Err(JsError::new("image embeddings have not been set"))?, @@ -94,7 +101,7 @@ impl Model { &embeddings.data, embeddings.height as usize, embeddings.width as usize, - &[(x, y, true)], + &transformed_points, false, )?; let iou = iou_predictions.flatten(0, 1)?.to_vec1::()?[0]; @@ -134,6 +141,11 @@ struct MaskImage { image: Image, } +#[derive(serde::Serialize, serde::Deserialize)] +struct PointsInput { + points: Vec<(f64, f64, bool)>, +} + fn main() { console_error_panic_hook::set_once(); }