mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add a wasm module for the segment anything example. (#797)
This commit is contained in:
@ -122,6 +122,11 @@ impl Sam {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
|
||||
let img = self.preprocess(img)?.unsqueeze(0)?;
|
||||
self.image_encoder.forward(&img)
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
@ -131,15 +136,32 @@ impl Sam {
|
||||
let (_c, original_h, original_w) = img.dims3()?;
|
||||
let img = self.preprocess(img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
self.forward_for_embeddings(
|
||||
&img_embeddings,
|
||||
original_h,
|
||||
original_w,
|
||||
point,
|
||||
multimask_output,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn forward_for_embeddings(
|
||||
&self,
|
||||
img_embeddings: &Tensor,
|
||||
original_h: usize,
|
||||
original_w: usize,
|
||||
point: Option<(f64, f64)>,
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let points = match point {
|
||||
None => None,
|
||||
Some((x, y)) => {
|
||||
let points = Tensor::new(
|
||||
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
|
||||
img.device(),
|
||||
img_embeddings.device(),
|
||||
)?;
|
||||
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
|
||||
let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
|
||||
Some((points, labels))
|
||||
}
|
||||
};
|
||||
@ -147,7 +169,7 @@ impl Sam {
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder.forward(points, None, None)?;
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
|
Reference in New Issue
Block a user