Segment Anything - process images (#766)

* Start processing images.

* Add LayerNorm2d.

* Properly use LayerNorm2d.

* Tweak eps.

* Use LayerNorm on inputs with a rank different from 3.

* Window partitioning.

* Fix a couple of TODOs.

* Address more TODOs.

* Hard-code the einsums.

* More padding support.

* Some size tweaks.

* Use the hub to get the weights.

* Use a batch matmul.

* Tweaks.

* More fixes.

* Get some predictions generated.

Author: Laurent Mazare
Date: 2023-09-07 19:22:45 +01:00 (committed by GitHub)
Parent: 7b50f3e106
Commit: 7396b8ed1a
10 changed files with 303 additions and 105 deletions


@@ -1,5 +1,5 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{Linear, Module, VarBuilder};
 
 #[derive(Debug)]
 struct PostionEmbeddingRandom {
@@ -17,7 +17,7 @@ impl PostionEmbeddingRandom {
     fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
         let coords = coords.affine(2., -1.)?;
-        let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?;
+        let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
         let coords = (coords * (2. * std::f64::consts::PI))?;
         Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
     }
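
The switch from matmul to broadcast_matmul lets a single (2, num_pos_feats) Gaussian matrix be applied to coordinate tensors with arbitrary leading dimensions; pe_encoding receives both (h, w, 2) grids from forward and (b, n, 2) point batches from forward_with_coords, and a plain matmul would require the batch dimensions to match. A minimal check of the broadcasting semantics, using only candle's public Tensor API (the sizes here are illustrative):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A (64, 64, 2) grid of coordinates against one shared (2, 128) matrix.
    let coords = Tensor::randn(0f32, 1f32, (64, 64, 2), &dev)?;
    let gaussian = Tensor::randn(0f32, 1f32, (2, 128), &dev)?;
    // broadcast_matmul broadcasts the rank-2 rhs across the leading dims,
    // where a plain matmul would fail on the mismatched batch shapes.
    let out = coords.broadcast_matmul(&gaussian)?;
    assert_eq!(out.dims(), &[64, 64, 128]);
    Ok(())
}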
@@ -25,12 +25,14 @@ impl PostionEmbeddingRandom {
     fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
         let device = self.positional_encoding_gaussian_matrix.device();
-        let grid = Tensor::ones((h, w), DType::F32, device)?;
-        // TODO: cumsum
-        let x_embed = (&grid - 0.5)?;
-        // TODO: cumsum
-        let y_embed = (&grid - 0.5)?;
-        let x_embed = (x_embed / w as f64)?;
-        let y_embed = (y_embed / h as f64)?;
+        let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+        let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+        let x_embed = (x_embed / w as f64)?
+            .reshape((1, w))?
+            .broadcast_as((h, w))?;
+        let y_embed = (y_embed / h as f64)?
+            .reshape((h, 1))?
+            .broadcast_as((h, w))?;
         let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
         self.pe_encoding(&coords)?.permute((2, 0, 1))
     }
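
This rewrite also resolves the cumsum TODOs: a cumulative sum over a ones tensor yields 1..=n along the axis, so the reference (cumsum - 0.5) / n is the same as (arange + 0.5) / n, i.e. pixel-center coordinates in [0, 1). A small check of that equivalence for w = 4, where every value is exactly representable in f32:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let w = 4usize;
    // (arange(w) + 0.5) / w gives the centers of w pixels in [0, 1),
    // matching the ((ones.cumsum(dim) - 0.5) / w) formulation it replaces.
    let x = (Tensor::arange(0u32, w as u32, &dev)?.to_dtype(DType::F32)? + 0.5)?;
    let x = (x / w as f64)?;
    assert_eq!(x.to_vec1::<f32>()?, [0.125, 0.375, 0.625, 0.875]);
    Ok(())
}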
@@ -55,13 +57,14 @@ pub struct PromptEncoder {
     point_embeddings: Vec<candle_nn::Embedding>,
     not_a_point_embed: candle_nn::Embedding,
     mask_downscaling_conv1: candle_nn::Conv2d,
-    mask_downscaling_ln1: LayerNorm,
+    mask_downscaling_ln1: crate::LayerNorm2d,
     mask_downscaling_conv2: candle_nn::Conv2d,
-    mask_downscaling_ln2: LayerNorm,
+    mask_downscaling_ln2: crate::LayerNorm2d,
     mask_downscaling_conv3: candle_nn::Conv2d,
     no_mask_embed: candle_nn::Embedding,
     image_embedding_size: (usize, usize),
     input_image_size: (usize, usize),
+    embed_dim: usize,
 }
 
 impl PromptEncoder {
@@ -97,8 +100,9 @@ impl PromptEncoder {
             vb.pp("mask_downscaling.6"),
         )?;
         let mask_downscaling_ln1 =
-            layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
-        let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+            crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+        let mask_downscaling_ln2 =
+            crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
         let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
         let vb_e = vb.pp("point_embeddings");
         for i in 0..num_points_embeddings {
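
crate::LayerNorm2d replaces candle_nn::layer_norm here because the mask-downscaling stack operates on (B, C, H, W) activations and must normalize across the channel dimension rather than the last one. The actual LayerNorm2d is defined in another file of this change set and is not shown in this diff; the following is a minimal sketch of a channels-first layer norm consistent with how it is used here, assuming candle's VarBuilder and Module APIs:

use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};

#[derive(Debug)]
pub struct LayerNorm2d {
    weight: Tensor, // (1, C, 1, 1) scale
    bias: Tensor,   // (1, C, 1, 1) shift
    eps: f64,
}

impl LayerNorm2d {
    pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get(num_channels, "weight")?.reshape((1, num_channels, 1, 1))?;
        let bias = vb.get(num_channels, "bias")?.reshape((1, num_channels, 1, 1))?;
        Ok(Self { weight, bias, eps })
    }
}

impl Module for LayerNorm2d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Normalize each spatial position across channels (dim 1 of NCHW).
        let u = xs.mean_keepdim(1)?;
        let xs = xs.broadcast_sub(&u)?;
        let s = xs.sqr()?.mean_keepdim(1)?;
        let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
        xs.broadcast_mul(&self.weight)?.broadcast_add(&self.bias)
    }
}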
@@ -117,9 +121,16 @@ impl PromptEncoder {
             no_mask_embed,
             image_embedding_size,
             input_image_size,
+            embed_dim,
         })
     }
 
+    pub fn get_dense_pe(&self) -> Result<Tensor> {
+        self.pe_layer
+            .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
+            .unsqueeze(0)
+    }
+
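The new get_dense_pe exposes the positional encoding for the full image-embedding grid, which the mask decoder later combines with the image features. Since pe_encoding concatenates sin and cos terms, the channel count is twice the number of positional features, and permute((2, 0, 1)) plus unsqueeze(0) turn the (h, w, c) encoding into an NCHW batch. A shape walk-through with illustrative sizes (128 positional features, a 64x64 grid):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for pe_encoding's output: (h, w, 2 * num_pos_feats).
    let pe = Tensor::zeros((64, 64, 256), DType::F32, &dev)?;
    // Channels first, plus a batch dimension, as get_dense_pe returns it.
    let dense_pe = pe.permute((2, 0, 1))?.unsqueeze(0)?;
    assert_eq!(dense_pe.dims(), &[1, 256, 64, 64]);
    Ok(())
}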
     fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
         masks
             .apply(&self.mask_downscaling_conv1)?
@@ -133,7 +144,16 @@ impl PromptEncoder {
     fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
         let points = (points + 0.5)?;
-        let points = if pad { todo!() } else { points };
+        let dev = points.device();
+        let (points, labels) = if pad {
+            let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
+            let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
+            let points = Tensor::cat(&[&points, &padding_point], 1)?;
+            let labels = Tensor::cat(&[labels, &padding_label], 1)?;
+            (points, labels)
+        } else {
+            (points, labels.clone())
+        };
         let point_embedding = self
             .pe_layer
             .forward_with_coords(&points, self.input_image_size)?;
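
When pad is true, a zero point with label -1 is appended to every batch entry; SAM uses this "not a point" token so the prompt layout stays fixed when points arrive without an accompanying box. In that convention, a label of 1 marks a foreground click and 0 a background click. A sketch of assembling inputs for a single positive click (coordinates are illustrative):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // One batch entry with one foreground click at (512, 512): points are
    // (batch, n_points, 2), labels (batch, n_points), label 1 = foreground.
    let points = Tensor::new(&[[[512f32, 512f32]]], &dev)?;
    let labels = Tensor::new(&[[1f32]], &dev)?;
    assert_eq!(points.dims(), &[1, 1, 2]);
    assert_eq!(labels.dims(), &[1, 1]);
    Ok(())
}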
@@ -154,7 +174,7 @@ impl PromptEncoder {
         Tensor::cat(&[&ce1, &ce2], 1)
     }
 
-    fn forward(
+    pub fn forward(
         &self,
         points: Option<(&Tensor, &Tensor)>,
         boxes: Option<&Tensor>,
@@ -172,7 +192,9 @@ impl PromptEncoder {
             (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
             (Some(se_points), None) => se_points,
             (None, Some(se_boxes)) => se_boxes,
-            (None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?,
+            (None, None) => {
+                Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
+            }
         };
 
         let dense_embeddings = match masks {
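
Returning Tensor::zeros((1, 0, self.embed_dim), ...) when neither points nor boxes are given keeps the sparse embeddings rank-3 with the correct channel size, so downstream concatenation and shape arithmetic need no special casing; the old Tensor::zeros(1, ...) was only a scalar-shaped placeholder. Zero-length dimensions are well-formed tensors in candle, shown here with an illustrative embed_dim of 256:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Zero prompt tokens, but still shaped (batch, n_tokens, embed_dim).
    let empty = Tensor::zeros((1, 0, 256), DType::F32, &dev)?;
    assert_eq!(empty.dims(), &[1, 0, 256]);
    assert_eq!(empty.elem_count(), 0);
    Ok(())
}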
@@ -180,7 +202,7 @@ impl PromptEncoder {
                 let emb = self.no_mask_embed.embeddings();
                 emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
                     1,
-                    0,
+                    emb.elem_count(),
                     self.image_embedding_size.0,
                     self.image_embedding_size.1,
                 ))?
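
This last hunk fixes the no-mask path: expanding to (1, 0, h, w) produced an empty dense embedding, while the intent is to broadcast the learned no-mask vector over every spatial position of the embedding grid. The corrected expand, with illustrative sizes (a 256-dim embedding over a 64x64 grid):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for no_mask_embed.embeddings(): a learned 256-dim vector.
    let emb = Tensor::randn(0f32, 1f32, 256, &dev)?;
    // (1, C, 1, 1) broadcast over the 64x64 grid; the old second expand
    // argument of 0 yielded a tensor with no channels at all.
    let dense = emb.reshape((1, 256, 1, 1))?.expand((1, 256, 64, 64))?;
    assert_eq!(dense.dims(), &[1, 256, 64, 64]);
    Ok(())
}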