mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Segment Anything - process images (#766)
* Start processing images.
* Add LayerNorm2d.
* Properly use LayerNorm2d.
* Tweak eps.
* Use LayerNorm on inputs with a rank different from 3.
* Window partitioning.
* Fix a couple todos.
* More todos.
* Hard-code the einsums.
* More padding support.
* Some sizes tweaks.
* Use the hub to get the weights.
* Use a batch matmul.
* Tweaks.
* More fixes.
* Get some predictions to be generated.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{Linear, Module, VarBuilder};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PostionEmbeddingRandom {
|
||||
@ -17,7 +17,7 @@ impl PostionEmbeddingRandom {
|
||||
|
||||
fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
|
||||
let coords = coords.affine(2., -1.)?;
|
||||
let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?;
|
||||
let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
|
||||
let coords = (coords * (2. * std::f64::consts::PI))?;
|
||||
Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
|
||||
}
|
||||
@ -25,12 +25,14 @@ impl PostionEmbeddingRandom {
|
||||
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
|
||||
let device = self.positional_encoding_gaussian_matrix.device();
|
||||
let grid = Tensor::ones((h, w), DType::F32, device)?;
|
||||
// TODO: cumsum
|
||||
let x_embed = (&grid - 0.5)?;
|
||||
// TODO: cumsum
|
||||
let y_embed = (&grid - 0.5)?;
|
||||
let x_embed = (x_embed / w as f64)?;
|
||||
let y_embed = (y_embed / h as f64)?;
|
||||
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let x_embed = (x_embed / w as f64)?
|
||||
.reshape((1, w))?
|
||||
.broadcast_as((h, w))?;
|
||||
let y_embed = (y_embed / h as f64)?
|
||||
.reshape((h, 1))?
|
||||
.broadcast_as((h, w))?;
|
||||
let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
|
||||
self.pe_encoding(&coords)?.permute((2, 0, 1))
|
||||
}
|
||||
@ -55,13 +57,14 @@ pub struct PromptEncoder {
|
||||
point_embeddings: Vec<candle_nn::Embedding>,
|
||||
not_a_point_embed: candle_nn::Embedding,
|
||||
mask_downscaling_conv1: candle_nn::Conv2d,
|
||||
mask_downscaling_ln1: LayerNorm,
|
||||
mask_downscaling_ln1: crate::LayerNorm2d,
|
||||
mask_downscaling_conv2: candle_nn::Conv2d,
|
||||
mask_downscaling_ln2: LayerNorm,
|
||||
mask_downscaling_ln2: crate::LayerNorm2d,
|
||||
mask_downscaling_conv3: candle_nn::Conv2d,
|
||||
no_mask_embed: candle_nn::Embedding,
|
||||
image_embedding_size: (usize, usize),
|
||||
input_image_size: (usize, usize),
|
||||
embed_dim: usize,
|
||||
}
|
||||
|
||||
impl PromptEncoder {
|
||||
@ -97,8 +100,9 @@ impl PromptEncoder {
|
||||
vb.pp("mask_downscaling.6"),
|
||||
)?;
|
||||
let mask_downscaling_ln1 =
|
||||
layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
|
||||
let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
|
||||
crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
|
||||
let mask_downscaling_ln2 =
|
||||
crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
|
||||
let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
|
||||
let vb_e = vb.pp("point_embeddings");
|
||||
for i in 0..num_points_embeddings {
|
||||
@ -117,9 +121,16 @@ impl PromptEncoder {
|
||||
no_mask_embed,
|
||||
image_embedding_size,
|
||||
input_image_size,
|
||||
embed_dim,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_dense_pe(&self) -> Result<Tensor> {
|
||||
self.pe_layer
|
||||
.forward(self.image_embedding_size.0, self.image_embedding_size.1)?
|
||||
.unsqueeze(0)
|
||||
}
|
||||
|
||||
fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
|
||||
masks
|
||||
.apply(&self.mask_downscaling_conv1)?
|
||||
@ -133,7 +144,16 @@ impl PromptEncoder {
|
||||
|
||||
fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
|
||||
let points = (points + 0.5)?;
|
||||
let points = if pad { todo!() } else { points };
|
||||
let dev = points.device();
|
||||
let (points, labels) = if pad {
|
||||
let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
|
||||
let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
|
||||
let points = Tensor::cat(&[&points, &padding_point], 1)?;
|
||||
let labels = Tensor::cat(&[labels, &padding_label], 1)?;
|
||||
(points, labels)
|
||||
} else {
|
||||
(points, labels.clone())
|
||||
};
|
||||
let point_embedding = self
|
||||
.pe_layer
|
||||
.forward_with_coords(&points, self.input_image_size)?;
|
||||
@ -154,7 +174,7 @@ impl PromptEncoder {
|
||||
Tensor::cat(&[&ce1, &ce2], 1)
|
||||
}
|
||||
|
||||
fn forward(
|
||||
pub fn forward(
|
||||
&self,
|
||||
points: Option<(&Tensor, &Tensor)>,
|
||||
boxes: Option<&Tensor>,
|
||||
@ -172,7 +192,9 @@ impl PromptEncoder {
|
||||
(Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
|
||||
(Some(se_points), None) => se_points,
|
||||
(None, Some(se_boxes)) => se_boxes,
|
||||
(None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?,
|
||||
(None, None) => {
|
||||
Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
|
||||
}
|
||||
};
|
||||
|
||||
let dense_embeddings = match masks {
|
||||
@ -180,7 +202,7 @@ impl PromptEncoder {
|
||||
let emb = self.no_mask_embed.embeddings();
|
||||
emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
|
||||
1,
|
||||
0,
|
||||
emb.elem_count(),
|
||||
self.image_embedding_size.0,
|
||||
self.image_embedding_size.1,
|
||||
))?
|
||||
|
Reference in New Issue
Block a user