mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
[segment-anything] add multi point logic for demo site (#1002)
* [segment-anything] add multi point logic for demo site * [segment-anything] remove libs and update functions
This commit is contained in:
@ -74,17 +74,24 @@ impl Model {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// x and y have to be between 0 and 1
|
||||
pub fn mask_for_point(&self, x: f64, y: f64) -> Result<JsValue, JsError> {
|
||||
if !(0. ..=1.).contains(&x) {
|
||||
Err(JsError::new(&format!(
|
||||
"x has to be between 0 and 1, got {x}"
|
||||
)))?
|
||||
}
|
||||
if !(0. ..=1.).contains(&y) {
|
||||
Err(JsError::new(&format!(
|
||||
"y has to be between 0 and 1, got {y}"
|
||||
)))?
|
||||
pub fn mask_for_point(&self, input: JsValue) -> Result<JsValue, JsError> {
|
||||
let input: PointsInput =
|
||||
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
let transformed_points = input.points;
|
||||
|
||||
for &(x, y, _bool) in &transformed_points {
|
||||
if !(0.0..=1.0).contains(&x) {
|
||||
return Err(JsError::new(&format!(
|
||||
"x has to be between 0 and 1, got {}",
|
||||
x
|
||||
)));
|
||||
}
|
||||
if !(0.0..=1.0).contains(&y) {
|
||||
return Err(JsError::new(&format!(
|
||||
"y has to be between 0 and 1, got {}",
|
||||
y
|
||||
)));
|
||||
}
|
||||
}
|
||||
let embeddings = match &self.embeddings {
|
||||
None => Err(JsError::new("image embeddings have not been set"))?,
|
||||
@ -94,7 +101,7 @@ impl Model {
|
||||
&embeddings.data,
|
||||
embeddings.height as usize,
|
||||
embeddings.width as usize,
|
||||
&[(x, y, true)],
|
||||
&transformed_points,
|
||||
false,
|
||||
)?;
|
||||
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
|
||||
@ -134,6 +141,11 @@ struct MaskImage {
|
||||
image: Image,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct PointsInput {
|
||||
points: Vec<(f64, f64, bool)>,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
|
Reference in New Issue
Block a user