mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
[segment-anything] add multi point logic for demo site (#1002)
* [segment-anything] add multi point logic for demo site * [segment-anything] remove libs and update functions
This commit is contained in:
@ -33,6 +33,7 @@
|
|||||||
url: "sam_vit_b_01ec64.safetensors",
|
url: "sam_vit_b_01ec64.safetensors",
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
let pointArr = []
|
||||||
const samWorker = new Worker("./samWorker.js", { type: "module" });
|
const samWorker = new Worker("./samWorker.js", { type: "module" });
|
||||||
|
|
||||||
async function segmentPoints(
|
async function segmentPoints(
|
||||||
@ -145,6 +146,7 @@
|
|||||||
//add event listener to clear button
|
//add event listener to clear button
|
||||||
clearBtn.addEventListener("click", () => {
|
clearBtn.addEventListener("click", () => {
|
||||||
cleanImageCanvas();
|
cleanImageCanvas();
|
||||||
|
pointArr = []
|
||||||
});
|
});
|
||||||
//add click event to canvas
|
//add click event to canvas
|
||||||
canvas.addEventListener("click", async (event) => {
|
canvas.addEventListener("click", async (event) => {
|
||||||
@ -155,7 +157,8 @@
|
|||||||
const x = (event.clientX - targetBox.left) / targetBox.width;
|
const x = (event.clientX - targetBox.left) / targetBox.width;
|
||||||
const y = (event.clientY - targetBox.top) / targetBox.height;
|
const y = (event.clientY - targetBox.top) / targetBox.height;
|
||||||
isSegmenting = true;
|
isSegmenting = true;
|
||||||
const { maskURL } = await getSegmentationMask({ x, y });
|
pointArr = [...pointArr, [ x, y , true ]]
|
||||||
|
const { maskURL } = await getSegmentationMask(pointArr);
|
||||||
isSegmenting = false;
|
isSegmenting = false;
|
||||||
drawMask(maskURL);
|
drawMask(maskURL);
|
||||||
});
|
});
|
||||||
|
@ -141,7 +141,7 @@ self.addEventListener("message", async (event) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.postMessage({ status: "segmenting", message: "Segmenting" });
|
self.postMessage({ status: "segmenting", message: "Segmenting" });
|
||||||
const { mask, image } = sam.mask_for_point(points.x, points.y);
|
const { mask, image } = sam.mask_for_point({ points });
|
||||||
const maskDataURL = await createImageCanvas(mask, image);
|
const maskDataURL = await createImageCanvas(mask, image);
|
||||||
// Send the segment back to the main thread as JSON
|
// Send the segment back to the main thread as JSON
|
||||||
self.postMessage({
|
self.postMessage({
|
||||||
|
@ -74,17 +74,24 @@ impl Model {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// x and y have to be between 0 and 1
|
pub fn mask_for_point(&self, input: JsValue) -> Result<JsValue, JsError> {
|
||||||
pub fn mask_for_point(&self, x: f64, y: f64) -> Result<JsValue, JsError> {
|
let input: PointsInput =
|
||||||
if !(0. ..=1.).contains(&x) {
|
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
|
||||||
Err(JsError::new(&format!(
|
let transformed_points = input.points;
|
||||||
"x has to be between 0 and 1, got {x}"
|
|
||||||
)))?
|
for &(x, y, _bool) in &transformed_points {
|
||||||
}
|
if !(0.0..=1.0).contains(&x) {
|
||||||
if !(0. ..=1.).contains(&y) {
|
return Err(JsError::new(&format!(
|
||||||
Err(JsError::new(&format!(
|
"x has to be between 0 and 1, got {}",
|
||||||
"y has to be between 0 and 1, got {y}"
|
x
|
||||||
)))?
|
)));
|
||||||
|
}
|
||||||
|
if !(0.0..=1.0).contains(&y) {
|
||||||
|
return Err(JsError::new(&format!(
|
||||||
|
"y has to be between 0 and 1, got {}",
|
||||||
|
y
|
||||||
|
)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let embeddings = match &self.embeddings {
|
let embeddings = match &self.embeddings {
|
||||||
None => Err(JsError::new("image embeddings have not been set"))?,
|
None => Err(JsError::new("image embeddings have not been set"))?,
|
||||||
@ -94,7 +101,7 @@ impl Model {
|
|||||||
&embeddings.data,
|
&embeddings.data,
|
||||||
embeddings.height as usize,
|
embeddings.height as usize,
|
||||||
embeddings.width as usize,
|
embeddings.width as usize,
|
||||||
&[(x, y, true)],
|
&transformed_points,
|
||||||
false,
|
false,
|
||||||
)?;
|
)?;
|
||||||
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
|
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
|
||||||
@ -134,6 +141,11 @@ struct MaskImage {
|
|||||||
image: Image,
|
image: Image,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
struct PointsInput {
|
||||||
|
points: Vec<(f64, f64, bool)>,
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
console_error_panic_hook::set_once();
|
console_error_panic_hook::set_once();
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user