candle/candle-examples/examples/segment-anything/model_prompt_encoder.rs
Laurent Mazare 7396b8ed1a Segment Anything - process images (#766)
* Start processing images.

* Add LayerNorm2d.

* Properly use LayerNorm2d.

* Tweak eps.

* Use LayerNorm on inputs with a rank different from 3.

* Window partitioning.

* Fix a couple todos.

* More todos.

* Hard-code the einsums.

* More padding support.

* Some sizes tweaks.

* Use the hub to get the weights.

* Use a batch matmul.

* Tweaks.

* More fixes.

* Get some predictions to be generated.
2023-09-07 19:22:45 +01:00

215 lines
8.0 KiB
Rust

use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
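
/// Positional encoding based on random spatial frequencies, mirroring the
/// `PositionEmbeddingRandom` module from the reference Segment Anything code:
/// normalized coordinates are projected through a fixed `(2, num_pos_feats)`
/// Gaussian matrix and mapped to sine/cosine features.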
#[derive(Debug)]
struct PositionEmbeddingRandom {
    positional_encoding_gaussian_matrix: Tensor,
}

impl PositionEmbeddingRandom {
    fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
        let positional_encoding_gaussian_matrix =
            vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
        Ok(Self {
            positional_encoding_gaussian_matrix,
        })
    }
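
    /// Encodes coordinates normalized to `[0, 1]`: they are rescaled to
    /// `[-1, 1]`, projected through the random Gaussian matrix, and expanded
    /// into concatenated `sin`/`cos` features along the last dimension.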
    fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
        let coords = coords.affine(2., -1.)?;
        let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
        let coords = (coords * (2. * std::f64::consts::PI))?;
        Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
    }
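
    /// Dense positional encoding for an `h` by `w` grid of pixel centers,
    /// returned with shape `(2 * num_pos_feats, h, w)`.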
    fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
        let device = self.positional_encoding_gaussian_matrix.device();
        // Mirrors the `torch.ones` grid of the reference implementation; it is
        // unused here since the cumulative sums are replaced by `arange` below.
        let _grid = Tensor::ones((h, w), DType::F32, device)?;
        let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
        let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
        let x_embed = (x_embed / w as f64)?
            .reshape((1, w))?
            .broadcast_as((h, w))?;
        let y_embed = (y_embed / h as f64)?
            .reshape((h, 1))?
            .broadcast_as((h, w))?;
        let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
        self.pe_encoding(&coords)?.permute((2, 0, 1))
    }
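
    /// Positional encoding for point coordinates given in pixels: x is
    /// normalized by the image width and y by the image height before
    /// encoding.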
    fn forward_with_coords(
        &self,
        coords_input: &Tensor,
        image_size: (usize, usize),
    ) -> Result<Tensor> {
        let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
        let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
        let c = coords_input.dim(D::Minus1)?;
        let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
        let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
        self.pe_encoding(&coords)
    }
}
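
/// Encodes point, box and mask prompts into the sparse and dense embeddings
/// consumed by the SAM mask decoder. The `mask_downscaling_*` fields follow
/// the indices of the `mask_downscaling` `Sequential` in the original PyTorch
/// checkpoint, hence the `mask_downscaling.0`, `.1`, ... weight names.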
#[derive(Debug)]
pub struct PromptEncoder {
    pe_layer: PositionEmbeddingRandom,
    point_embeddings: Vec<candle_nn::Embedding>,
    not_a_point_embed: candle_nn::Embedding,
    mask_downscaling_conv1: candle_nn::Conv2d,
    mask_downscaling_ln1: crate::LayerNorm2d,
    mask_downscaling_conv2: candle_nn::Conv2d,
    mask_downscaling_ln2: crate::LayerNorm2d,
    mask_downscaling_conv3: candle_nn::Conv2d,
    no_mask_embed: candle_nn::Embedding,
    image_embedding_size: (usize, usize),
    input_image_size: (usize, usize),
    embed_dim: usize,
}

impl PromptEncoder {
    pub fn new(
        embed_dim: usize,
        image_embedding_size: (usize, usize),
        input_image_size: (usize, usize),
        mask_in_chans: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let num_points_embeddings = 4;
        let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
        let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
        let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
        let cfg = candle_nn::Conv2dConfig {
            stride: 2,
            ..Default::default()
        };
        let mask_downscaling_conv1 =
            candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
        let mask_downscaling_conv2 = candle_nn::conv2d(
            mask_in_chans / 4,
            mask_in_chans,
            2,
            cfg,
            vb.pp("mask_downscaling.3"),
        )?;
        let mask_downscaling_conv3 = candle_nn::conv2d(
            mask_in_chans,
            embed_dim,
            1,
            Default::default(),
            vb.pp("mask_downscaling.6"),
        )?;
        let mask_downscaling_ln1 =
            crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
        let mask_downscaling_ln2 =
            crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
        let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
        let vb_e = vb.pp("point_embeddings");
        for i in 0..num_points_embeddings {
            let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
            point_embeddings.push(emb)
        }
        Ok(Self {
            pe_layer,
            point_embeddings,
            not_a_point_embed,
            mask_downscaling_conv1,
            mask_downscaling_ln1,
            mask_downscaling_conv2,
            mask_downscaling_ln2,
            mask_downscaling_conv3,
            no_mask_embed,
            image_embedding_size,
            input_image_size,
            embed_dim,
        })
    }
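
    /// Positional encoding over the full image-embedding grid, with a leading
    /// batch dimension: `(1, embed_dim, h, w)` where `(h, w)` is
    /// `image_embedding_size`.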
    pub fn get_dense_pe(&self) -> Result<Tensor> {
        self.pe_layer
            .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
            .unsqueeze(0)
    }
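
    /// Downscales a mask prompt by 4x through two strided convolutions and
    /// projects it to `embed_dim` channels.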
    fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
        masks
            .apply(&self.mask_downscaling_conv1)?
            .apply(&self.mask_downscaling_ln1)?
            .gelu()?
            .apply(&self.mask_downscaling_conv2)?
            .apply(&self.mask_downscaling_ln2)?
            .gelu()?
            .apply(&self.mask_downscaling_conv3)
    }
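
    /// Embeds point prompts. Coordinates are shifted by 0.5 to the pixel
    /// center and, when no box prompt is present, a padding point with label
    /// -1 is appended, following the reference implementation.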
    fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
        let points = (points + 0.5)?;
        let dev = points.device();
        let (points, labels) = if pad {
            let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
            let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
            let points = Tensor::cat(&[&points, &padding_point], 1)?;
            let labels = Tensor::cat(&[labels, &padding_label], 1)?;
            (points, labels)
        } else {
            (points, labels.clone())
        };
        let point_embedding = self
            .pe_layer
            .forward_with_coords(&points, self.input_image_size)?;
        // TODO: tweak based on labels. In the reference implementation, points
        // with label -1 (padding) get `not_a_point_embed` instead of their
        // positional encoding, and `point_embeddings[0]`/`[1]` are added for
        // negative/positive clicks respectively.
        Ok(point_embedding)
    }
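
    /// Embeds box prompts: each box is encoded as its two corner points, with
    /// dedicated corner embeddings added to their positional encodings.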
    fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
        let boxes = (boxes + 0.5)?;
        let coords = boxes.reshape((boxes.elem_count() / 4, 2, 2))?;
        let corner_embedding = self
            .pe_layer
            .forward_with_coords(&coords, self.input_image_size)?;
        let ce1 = corner_embedding.i((.., 0))?;
        let ce2 = corner_embedding.i((.., 1))?;
        let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
        let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
        Tensor::cat(&[&ce1, &ce2], 1)
    }
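
    /// Builds the sparse embeddings (from points and/or boxes) and the dense
    /// embeddings (from a mask, or from `no_mask_embed` when no mask is given)
    /// for a batch of prompts.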
    pub fn forward(
        &self,
        points: Option<(&Tensor, &Tensor)>,
        boxes: Option<&Tensor>,
        masks: Option<&Tensor>,
    ) -> Result<(Tensor, Tensor)> {
        let se_points = match points {
            Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
            None => None,
        };
        let se_boxes = match boxes {
            Some(boxes) => Some(self.embed_boxes(boxes)?),
            None => None,
        };
        let sparse_embeddings = match (se_points, se_boxes) {
            (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
            (Some(se_points), None) => se_points,
            (None, Some(se_boxes)) => se_boxes,
            // With no sparse prompt at all, fall back to an empty embedding
            // allocated on the CPU.
            (None, None) => {
                Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
            }
        };
        let dense_embeddings = match masks {
            None => {
                // Broadcast the learned `no_mask_embed` over the whole
                // image-embedding grid.
                let emb = self.no_mask_embed.embeddings();
                emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
                    1,
                    emb.elem_count(),
                    self.image_embedding_size.0,
                    self.image_embedding_size.1,
                ))?
            }
            Some(masks) => self.embed_masks(masks)?,
        };
        Ok((sparse_embeddings, dense_embeddings))
    }
}
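
#[cfg(test)]
mod tests {
    use super::*;
    use candle::Device;

    // A minimal shape-check sketch, not part of the original example: it
    // assumes SAM-ViT-style hyperparameters (embed_dim = 256, a 64x64 image
    // embedding grid, 1024x1024 inputs, mask_in_chans = 16) and runs the
    // encoder with zero-initialized weights instead of a real checkpoint.
    #[test]
    fn prompt_encoder_output_shapes() -> Result<()> {
        let dev = Device::Cpu;
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let encoder = PromptEncoder::new(256, (64, 64), (1024, 1024), 16, vb)?;
        // A single "positive click" at pixel (512, 512); labels are f32 so
        // they can be concatenated with the -1 padding label.
        let points = Tensor::from_vec(vec![512f32, 512.], (1, 1, 2), &dev)?;
        let labels = Tensor::from_vec(vec![1f32], (1, 1), &dev)?;
        let (sparse, dense) = encoder.forward(Some((&points, &labels)), None, None)?;
        // One click plus the padding point, each mapped to an embed_dim vector.
        assert_eq!(sparse.dims(), &[1, 2, 256]);
        // The dense embedding covers the image-embedding grid.
        assert_eq!(dense.dims(), &[1, 256, 64, 64]);
        Ok(())
    }
}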