Add a wasm module for the segment anything example. (#797)

This commit is contained in:
Laurent Mazare
2023-09-10 12:29:37 +01:00
committed by GitHub
parent 6c58fc59fd
commit 584171cae1
6 changed files with 189 additions and 3 deletions

View File

@ -122,6 +122,11 @@ impl Sam {
})
}
pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
let img = self.preprocess(img)?.unsqueeze(0)?;
self.image_encoder.forward(&img)
}
pub fn forward(
&self,
img: &Tensor,
@ -131,15 +136,32 @@ impl Sam {
let (_c, original_h, original_w) = img.dims3()?;
let img = self.preprocess(img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
self.forward_for_embeddings(
&img_embeddings,
original_h,
original_w,
point,
multimask_output,
)
}
pub fn forward_for_embeddings(
&self,
img_embeddings: &Tensor,
original_h: usize,
original_w: usize,
point: Option<(f64, f64)>,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = match point {
None => None,
Some((x, y)) => {
let points = Tensor::new(
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
img.device(),
img_embeddings.device(),
)?;
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
Some((points, labels))
}
};
@ -147,7 +169,7 @@ impl Sam {
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder.forward(points, None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,