mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Segment Anything - process images (#766)
* Start processing images. * Add LayerNorm2d. * Properly use LayerNorm2d. * Tweak eps. * Use LayerNorm on inputs with a rank different from 3. * Window partitioning. * Fix a couple todos. * More todos. * Hard-code the einsums. * More padding support. * Some sizes tweaks. * Use the hub to get the weights. * Use a batch matmul. * Tweaks. * More fixes. * Get some predictions to be generated.
This commit is contained in:
@ -5,6 +5,10 @@ use crate::model_image_encoder::ImageEncoderViT;
|
||||
use crate::model_mask_decoder::MaskDecoder;
|
||||
use crate::model_prompt_encoder::PromptEncoder;
|
||||
|
||||
const PROMPT_EMBED_DIM: usize = 256;
|
||||
const IMAGE_SIZE: usize = 1024;
|
||||
const VIT_PATCH_SIZE: usize = 16;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Sam {
|
||||
image_encoder: ImageEncoderViT,
|
||||
@ -22,10 +26,6 @@ impl Sam {
|
||||
encoder_global_attn_indexes: &[usize],
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
const PROMPT_EMBED_DIM: usize = 256;
|
||||
const IMAGE_SIZE: usize = 1024;
|
||||
const VIT_PATCH_SIZE: usize = 16;
|
||||
|
||||
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
|
||||
|
||||
let image_encoder = ImageEncoderViT::new(
|
||||
@ -69,4 +69,33 @@ impl Sam {
|
||||
pixel_mean,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> {
|
||||
let img = self.preprocess(img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder.forward(None, None, None)?;
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
multimask_output,
|
||||
)?;
|
||||
// TODO: post-processing.
|
||||
Ok((low_res_mask, iou_predictions))
|
||||
}
|
||||
|
||||
fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||
let (c, h, w) = img.dims3()?;
|
||||
let img = img
|
||||
.broadcast_sub(&self.pixel_mean)?
|
||||
.broadcast_div(&self.pixel_std)?;
|
||||
if h > IMAGE_SIZE || w > IMAGE_SIZE {
|
||||
candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
|
||||
}
|
||||
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
|
||||
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user