Return the low res mask in the wasm segment-anything module. (#798)

* Return the low res mask.

* Add some validations.
This commit is contained in:
Laurent Mazare
2023-09-10 13:03:02 +01:00
committed by GitHub
parent 584171cae1
commit 90e077e409
2 changed files with 20 additions and 10 deletions

View File

@ -136,13 +136,18 @@ impl Sam {
let (_c, original_h, original_w) = img.dims3()?;
let img = self.preprocess(img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
self.forward_for_embeddings(
let (low_res_mask, iou) = self.forward_for_embeddings(
&img_embeddings,
original_h,
original_w,
point,
multimask_output,
)
)?;
let mask = low_res_mask
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
.get(0)?
.i((.., ..original_h, ..original_w))?;
Ok((mask, iou))
}
pub fn forward_for_embeddings(
@ -168,18 +173,13 @@ impl Sam {
let points = points.as_ref().map(|(x, y)| (x, y));
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder.forward(points, None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
self.mask_decoder.forward(
img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
multimask_output,
)?;
let mask = low_res_mask
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
.get(0)?
.i((.., ..original_h, ..original_w))?;
Ok((mask, iou_predictions))
)
}
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {

View File

@ -77,8 +77,18 @@ impl Model {
// x and y have to be between 0 and 1
pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> {
if !(0. ..=1.).contains(&x) {
Err(JsError::new(&format!(
"x has to be between 0 and 1, got {x}"
)))?
}
if !(0. ..=1.).contains(&y) {
Err(JsError::new(&format!(
"y has to be between 0 and 1, got {y}"
)))?
}
let embeddings = match &self.embeddings {
None => todo!(),
None => Err(JsError::new("image embeddings have not been set"))?,
Some(embeddings) => embeddings,
};
let (mask, iou_predictions) = self.sam.forward_for_embeddings(