mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a wasm module for the segment anything example. (#797)
This commit is contained in:
29
candle-wasm-examples/segment-anything/Cargo.toml
Normal file
29
candle-wasm-examples/segment-anything/Cargo.toml
Normal file
@ -0,0 +1,29 @@
|
||||
[package]
|
||||
name = "candle-wasm-example-sam"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.2.1" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.2.1" }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
# App crates.
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
image = { workspace = true }
|
||||
log = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
wasm-bindgen = "0.2.87"
|
2
candle-wasm-examples/segment-anything/build-lib.sh
Normal file
2
candle-wasm-examples/segment-anything/build-lib.sh
Normal file
@ -0,0 +1,2 @@
|
||||
cargo build --target wasm32-unknown-unknown --release
|
||||
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
|
113
candle-wasm-examples/segment-anything/src/bin/m.rs
Normal file
113
candle-wasm-examples/segment-anything/src/bin/m.rs
Normal file
@ -0,0 +1,113 @@
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_wasm_example_sam as sam;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
#[allow(unused)]
|
||||
struct Embeddings {
|
||||
original_width: u32,
|
||||
original_height: u32,
|
||||
width: u32,
|
||||
height: u32,
|
||||
data: Tensor,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct Model {
|
||||
sam: sam::Sam,
|
||||
embeddings: Option<Embeddings>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl Model {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
let dev = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
|
||||
let sam = if use_tiny {
|
||||
sam::Sam::new_tiny(vb)? // tiny vit_t
|
||||
} else {
|
||||
sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
|
||||
};
|
||||
Ok(Self {
|
||||
sam,
|
||||
embeddings: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> {
|
||||
sam::console_log!("image data: {}", image_data.len());
|
||||
let image_data = std::io::Cursor::new(image_data);
|
||||
let image = image::io::Reader::new(image_data)
|
||||
.with_guessed_format()?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let (original_height, original_width) = (image.height(), image.width());
|
||||
let (height, width) = (original_height, original_width);
|
||||
let resize_longest = sam::IMAGE_SIZE as u32;
|
||||
let (height, width) = if height < width {
|
||||
let h = (resize_longest * height) / width;
|
||||
(h, resize_longest)
|
||||
} else {
|
||||
let w = (resize_longest * width) / height;
|
||||
(resize_longest, w)
|
||||
};
|
||||
let image_t = {
|
||||
let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom);
|
||||
let data = img.to_rgb8().into_raw();
|
||||
Tensor::from_vec(
|
||||
data,
|
||||
(img.height() as usize, img.width() as usize, 3),
|
||||
&Device::Cpu,
|
||||
)?
|
||||
.permute((2, 0, 1))?
|
||||
};
|
||||
let data = self.sam.embeddings(&image_t)?;
|
||||
self.embeddings = Some(Embeddings {
|
||||
original_width,
|
||||
original_height,
|
||||
width,
|
||||
height,
|
||||
data,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// x and y have to be between 0 and 1
|
||||
pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> {
|
||||
let embeddings = match &self.embeddings {
|
||||
None => todo!(),
|
||||
Some(embeddings) => embeddings,
|
||||
};
|
||||
let (mask, iou_predictions) = self.sam.forward_for_embeddings(
|
||||
&embeddings.data,
|
||||
embeddings.height as usize,
|
||||
embeddings.width as usize,
|
||||
Some((x, y)),
|
||||
false,
|
||||
)?;
|
||||
let iou = iou_predictions.to_vec1::<f32>()?[0];
|
||||
let mask_shape = mask.dims().to_vec();
|
||||
let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
|
||||
let mask = Mask {
|
||||
iou,
|
||||
mask_shape,
|
||||
mask_data,
|
||||
};
|
||||
let json = serde_json::to_string(&mask)?;
|
||||
Ok(json)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct Mask {
|
||||
iou: f32,
|
||||
mask_shape: Vec<usize>,
|
||||
mask_data: Vec<u8>,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
19
candle-wasm-examples/segment-anything/src/lib.rs
Normal file
19
candle-wasm-examples/segment-anything/src/lib.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use candle_transformers::models::segment_anything::sam;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
pub use sam::{Sam, IMAGE_SIZE};
|
||||
|
||||
#[wasm_bindgen]
|
||||
extern "C" {
|
||||
// Use `js_namespace` here to bind `console.log(..)` instead of just
|
||||
// `log(..)`
|
||||
#[wasm_bindgen(js_namespace = console)]
|
||||
pub fn log(s: &str);
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! console_log {
|
||||
// Note that this is using the `log` function imported above during
|
||||
// `bare_bones`
|
||||
($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
|
||||
}
|
Reference in New Issue
Block a user