mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Return the low res mask in the wasm segment-anything module. (#798)
* Return the low res mask. * Add some validations.
This commit is contained in:
@ -136,13 +136,18 @@ impl Sam {
|
||||
let (_c, original_h, original_w) = img.dims3()?;
|
||||
let img = self.preprocess(img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
self.forward_for_embeddings(
|
||||
let (low_res_mask, iou) = self.forward_for_embeddings(
|
||||
&img_embeddings,
|
||||
original_h,
|
||||
original_w,
|
||||
point,
|
||||
multimask_output,
|
||||
)
|
||||
)?;
|
||||
let mask = low_res_mask
|
||||
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
|
||||
.get(0)?
|
||||
.i((.., ..original_h, ..original_w))?;
|
||||
Ok((mask, iou))
|
||||
}
|
||||
|
||||
pub fn forward_for_embeddings(
|
||||
@ -168,18 +173,13 @@ impl Sam {
|
||||
let points = points.as_ref().map(|(x, y)| (x, y));
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder.forward(points, None, None)?;
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
self.mask_decoder.forward(
|
||||
img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
multimask_output,
|
||||
)?;
|
||||
let mask = low_res_mask
|
||||
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
|
||||
.get(0)?
|
||||
.i((.., ..original_h, ..original_w))?;
|
||||
Ok((mask, iou_predictions))
|
||||
)
|
||||
}
|
||||
|
||||
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||
|
@ -77,8 +77,18 @@ impl Model {
|
||||
|
||||
// x and y have to be between 0 and 1
|
||||
pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> {
|
||||
if !(0. ..=1.).contains(&x) {
|
||||
Err(JsError::new(&format!(
|
||||
"x has to be between 0 and 1, got {x}"
|
||||
)))?
|
||||
}
|
||||
if !(0. ..=1.).contains(&y) {
|
||||
Err(JsError::new(&format!(
|
||||
"y has to be between 0 and 1, got {y}"
|
||||
)))?
|
||||
}
|
||||
let embeddings = match &self.embeddings {
|
||||
None => todo!(),
|
||||
None => Err(JsError::new("image embeddings have not been set"))?,
|
||||
Some(embeddings) => embeddings,
|
||||
};
|
||||
let (mask, iou_predictions) = self.sam.forward_for_embeddings(
|
||||
|
Reference in New Issue
Block a user