Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
Segment Anything - process images (#766)
* Start processing images.
* Add LayerNorm2d.
* Properly use LayerNorm2d.
* Tweak eps.
* Use LayerNorm on inputs with a rank different from 3.
* Window partitioning.
* Fix a couple todos.
* More todos.
* Hard-code the einsums.
* More padding support.
* Some sizes tweaks.
* Use the hub to get the weights.
* Use a batch matmul.
* Tweaks.
* More fixes.
* Get some predictions to be generated.
@@ -30,7 +30,7 @@ pub enum Error {
     UnsupportedDTypeForOp(DType, &'static str),
 
     // === Dimension Index Errors ===
-    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
+    #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
     DimOutOfRange {
         shape: Shape,
         dim: i32,
@@ -73,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
     }
 }
 
+impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
+    fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
+        Self(vec![
+            d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
+        ])
+    }
+}
+
 impl From<Vec<usize>> for Shape {
     fn from(dims: Vec<usize>) -> Self {
         Self(dims)
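The new impl above extends the existing tuple conversions to rank 6, which the window partitioning added later in this diff needs for its six-dimensional reshape. A stand-alone sketch of the same pattern, using a local Shape stand-in rather than the candle-core type:

    // Stand-alone illustration of the tuple-to-shape conversion added above.
    // `Shape` here is a local stand-in, not the candle_core type.
    #[derive(Debug, PartialEq)]
    struct Shape(Vec<usize>);

    impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
        fn from(d: (usize, usize, usize, usize, usize, usize)) -> Self {
            Shape(vec![d.0, d.1, d.2, d.3, d.4, d.5])
        }
    }

    fn main() {
        // A (batch, h_windows, window, w_windows, window, channels)-style 6D shape.
        let shape: Shape = (1, 5, 14, 5, 14, 768).into();
        assert_eq!(shape, Shape(vec![1, 5, 14, 5, 14, 768]));
        println!("{shape:?}");
    }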
@@ -15,7 +15,7 @@ pub mod model_sam;
 pub mod model_transformer;
 
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{Linear, Module, VarBuilder};
 use clap::Parser;
 
 pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
@@ -26,65 +26,74 @@ pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
     }
 }
 
+#[derive(Debug)]
+pub struct LayerNorm2d {
+    weight: Tensor,
+    bias: Tensor,
+    num_channels: usize,
+    eps: f64,
+}
+
+impl LayerNorm2d {
+    pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let weight = vb.get(num_channels, "weight")?;
+        let bias = vb.get(num_channels, "bias")?;
+        Ok(Self {
+            weight,
+            bias,
+            num_channels,
+            eps,
+        })
+    }
+}
+
+impl Module for LayerNorm2d {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let u = xs.mean_keepdim(1)?;
+        let xs = xs.broadcast_sub(&u)?;
+        let s = xs.sqr()?.mean_keepdim(1)?;
+        let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
+        xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
+            .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
+    }
+}
+
 #[derive(Debug)]
 pub struct MlpBlock {
     lin1: Linear,
     lin2: Linear,
+    activation: candle_nn::Activation,
 }
 
 impl MlpBlock {
-    pub fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
+    pub fn new(
+        embedding_dim: usize,
+        mlp_dim: usize,
+        activation: candle_nn::Activation,
+        vb: VarBuilder,
+    ) -> Result<Self> {
         let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
         let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
-        Ok(Self { lin1, lin2 })
+        Ok(Self {
+            lin1,
+            lin2,
+            activation,
+        })
     }
 }
 
 impl Module for MlpBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
+        xs.apply(&self.lin1)?
+            .apply(&self.activation)?
+            .apply(&self.lin2)
     }
 }
 
-/*
-fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
-    let npatch = xs.dim(1)? - 1;
-    let n = self.pos_embed.dim(1)? - 1;
-    let sqrt_n = (n as f64).sqrt();
-    if npatch == n && w == h {
-        return Ok(xs.clone());
-    }
-    let class_pos_embed = self.pos_embed.i((.., ..1))?;
-    let patch_pos_embed = self.pos_embed.i((.., 1..))?;
-    let dim = xs.dim(D::Minus1)?;
-    let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
-    let patch_pos_embed = patch_pos_embed
-        .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
-        .transpose(2, 3)?
-        .transpose(1, 2)?;
-    // This uses bicubic interpolation in the original implementation.
-    let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
-    let el_count = patch_pos_embed.shape().elem_count();
-    let patch_pos_embed =
-        patch_pos_embed
-            .transpose(1, 2)?
-            .transpose(2, 3)?
-            .reshape((1, el_count / dim, dim))?;
-    Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
-}
-
-fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
-    let (_b, _nc, w, h) = xs.dims4()?;
-    let xs = self.patch_embed.forward(xs)?;
-    let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
-    &xs + &self.interpolate_pos_encoding(&xs, w, h)?
-}
-*/
-
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
-    model: String,
+    model: Option<String>,
 
     #[arg(long)]
     image: String,
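A note on the LayerNorm2d added above: it normalizes an NCHW tensor over the channel dimension, i.e. mean and variance are taken across C separately for every (batch, y, x) position, and a learned per-channel weight and bias are then applied. A minimal std-only sketch of that computation on a flat buffer (illustrative only; the real module operates on candle Tensors):

    // Per-position channel LayerNorm on an NCHW buffer, mirroring LayerNorm2d above.
    // Plain-std sketch for illustration; the real code works on candle Tensors.
    fn layer_norm_2d(
        x: &[f32],          // n * c * h * w values, NCHW layout
        (n, c, h, w): (usize, usize, usize, usize),
        weight: &[f32],     // per-channel scale, length c
        bias: &[f32],       // per-channel shift, length c
        eps: f32,
    ) -> Vec<f32> {
        let mut out = vec![0f32; x.len()];
        let plane = h * w;
        for ni in 0..n {
            for p in 0..plane {
                // Gather the C values living at this (n, y, x) position.
                let idx = |ci: usize| ni * c * plane + ci * plane + p;
                let mean = (0..c).map(|ci| x[idx(ci)]).sum::<f32>() / c as f32;
                let var = (0..c).map(|ci| (x[idx(ci)] - mean).powi(2)).sum::<f32>() / c as f32;
                let inv_std = 1.0 / (var + eps).sqrt();
                for ci in 0..c {
                    out[idx(ci)] = (x[idx(ci)] - mean) * inv_std * weight[ci] + bias[ci];
                }
            }
        }
        out
    }

    fn main() {
        let x = vec![1.0, 2.0, 3.0, 4.0]; // n=1, c=2, h=1, w=2
        let y = layer_norm_2d(&x, (1, 2, 1, 2), &[1.0, 1.0], &[0.0, 0.0], 1e-6);
        println!("{y:?}"); // each (y, x) position is normalized across its 2 channels
    }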
@@ -99,13 +108,24 @@ pub fn main() -> anyhow::Result<()> {
 
     let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device);
+    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
     println!("loaded image {image:?}");
 
-    let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+    let model = match args.model {
+        Some(model) => std::path::PathBuf::from(model),
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("lmz/candle-sam".to_string());
+            api.get("sam_vit_b_01ec64.safetensors")?
+        }
+    };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
-    let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+    let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
 
+    let (mask, iou_predictions) = sam.forward(&image, false)?;
+    println!("mask: {mask:?}");
+    println!("iou_predictions: {iou_predictions:?}");
     Ok(())
 }
@@ -70,6 +70,60 @@ impl Attention {
             rel_pos_hw,
         })
     }
+
+    fn add_decomposed_rel_pos(
+        &self,
+        attn: Tensor,
+        q: &Tensor,
+        (q_h, q_w): (usize, usize),
+        (k_h, k_w): (usize, usize),
+    ) -> Result<Tensor> {
+        match &self.rel_pos_hw {
+            Some((rel_pos_h, rel_pos_w)) => {
+                let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
+                let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
+                let (b, _, dim) = q.dims3()?;
+                let r_q = q.reshape((b, q_h, q_w, dim))?;
+                // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+                let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
+                // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+                let rel_w = r_q
+                    .transpose(1, 2)? // -> bwhc
+                    .contiguous()?
+                    .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
+                    .transpose(1, 2)?;
+                (attn.reshape((b, q_h, q_w, k_h, k_w))?
+                    + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+                    .reshape((b, q_h * q_w, k_h * k_w))
+            }
+            None => Ok(attn),
+        }
+    }
 }
 
+fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
+    let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
+    let dev = rel_pos.device();
+    let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
+        todo!("interpolation")
+    } else {
+        rel_pos
+    };
+    let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
+        .reshape((q_size, 1))?
+        .to_dtype(DType::F32)?;
+    let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
+        .reshape((1, k_size))?
+        .to_dtype(DType::F32)?;
+    let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
+    let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
+    let relative_coords = (q_coords.broadcast_sub(&k_coords)?
+        + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
+    let (d1, d2) = relative_coords.dims2()?;
+    let relative_coords = relative_coords.to_dtype(DType::U32)?;
+    rel_pos_resized
+        .index_select(&relative_coords.reshape(d1 * d2)?, 0)?
+        .reshape((d1, d2, rel_pos_resized.dim(1)?))
+}
+
 impl Module for Attention {
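The get_rel_pos helper above maps every (query, key) index pair to a row of the learned relative-position table: coordinates are rescaled when q_size != k_size, and an offset shifts the result so all indices land in 0..=2*max(q_size, k_size)-2. A std-only sketch of just the index arithmetic (the diff computes the same values with broadcasted tensors and then index_select):

    // Index table used by get_rel_pos above: entry (qi, ki) selects which row of the
    // learned rel_pos table describes the offset between query qi and key ki.
    // Std-only sketch of the broadcasted tensor arithmetic in the diff.
    fn rel_pos_indices(q_size: usize, k_size: usize) -> Vec<Vec<usize>> {
        let q_ratio = f64::max(1.0, k_size as f64 / q_size as f64);
        let k_ratio = f64::max(1.0, q_size as f64 / k_size as f64);
        let offset = (k_size as f64 - 1.0) * k_ratio;
        (0..q_size)
            .map(|qi| {
                (0..k_size)
                    .map(|ki| (qi as f64 * q_ratio - ki as f64 * k_ratio + offset) as usize)
                    .collect()
            })
            .collect()
    }

    fn main() {
        // For q_size == k_size == 4 the rows run from [3, 2, 1, 0] to [6, 5, 4, 3],
        // i.e. every index stays below 2 * 4 - 1, the size of the rel_pos table.
        for row in rel_pos_indices(4, 4) {
            println!("{row:?}");
        }
    }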
@@ -77,24 +131,22 @@ impl Module for Attention {
         let (b, h, w, c) = xs.dims4()?;
         let qkv = self
             .qkv
-            .forward(xs)?
+            .forward(&xs.flatten_to(1)?)?
             .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
             .permute((2, 0, 3, 1, 4))?
             .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
         let q = qkv.i(0)?;
         let k = qkv.i(1)?;
         let v = qkv.i(2)?;
-        let attn = (q * self.scale)?.matmul(&k.t()?)?;
-        if self.use_rel_pos {
-            todo!()
-        }
+        let attn = (&q * self.scale)?.matmul(&k.t()?)?;
+        let attn = self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?;
         let attn = candle_nn::ops::softmax_last_dim(&attn)?;
-        let attn = attn.matmul(&v)?;
         let attn = attn
+            .matmul(&v)?
             .reshape((b, self.num_heads, h, w, c / self.num_heads))?
             .permute((0, 2, 3, 1, 4))?
-            .reshape((b, h, w, c / self.num_heads))?;
-        self.proj.forward(&attn)
+            .reshape((b, h * w, c))?;
+        self.proj.forward(&attn)?.reshape((b, h, w, c))
     }
 }
@@ -117,8 +169,8 @@ impl Block {
         input_size: (usize, usize),
         vb: VarBuilder,
     ) -> Result<Self> {
-        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
-        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+        let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
+        let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
         let input_size_attn = if window_size == 0 {
             input_size
         } else {
@@ -132,7 +184,7 @@ impl Block {
             input_size_attn,
             vb.pp("attn"),
         )?;
-        let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
+        let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
         Ok(Self {
             norm1,
             attn,
@@ -143,17 +195,76 @@ impl Block {
     }
 }
 
+fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
+    let (b, h, w, c) = xs.dims4()?;
+    let pad_h = (window_size - h % window_size) % window_size;
+    let pad_w = (window_size - w % window_size) % window_size;
+    let xs = if pad_h > 0 {
+        xs.pad_with_zeros(1, 0, pad_h)?
+    } else {
+        xs
+    };
+    let xs = if pad_w > 0 {
+        xs.pad_with_zeros(2, 0, pad_w)?
+    } else {
+        xs
+    };
+    let (h_p, w_p) = (h + pad_h, w + pad_w);
+    let windows = xs
+        .reshape((
+            b,
+            h_p / window_size,
+            window_size,
+            w_p / window_size,
+            window_size,
+            c,
+        ))?
+        .transpose(2, 3)?
+        .contiguous()?
+        .flatten_to(2)?;
+    Ok((windows, (h_p, w_p)))
+}
+
+fn window_unpartition(
+    windows: Tensor,
+    window_size: usize,
+    (h_p, w_p): (usize, usize),
+    (h, w): (usize, usize),
+) -> Result<Tensor> {
+    let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
+    let xs = windows
+        .reshape((
+            b,
+            h_p / window_size,
+            w_p / window_size,
+            window_size,
+            window_size,
+            windows.elem_count() / b / h_p / w_p,
+        ))?
+        .transpose(2, 3)?
+        .contiguous()?
+        .reshape((b, h_p, w_p, windows.elem_count() / b / h_p / w_p))?;
+    let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
+    let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
+    Ok(xs)
+}
+
 impl Module for Block {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let shortcut = xs;
         let xs = self.norm1.forward(xs)?;
-        if self.window_size > 0 {
-            todo!()
-        }
+        let hw = (xs.dim(1)?, xs.dim(2)?);
+        let (xs, pad_hw) = if self.window_size > 0 {
+            window_partition(xs, self.window_size)?
+        } else {
+            (xs, (0, 0))
+        };
         let xs = self.attn.forward(&xs)?;
-        if self.window_size > 0 {
-            todo!()
-        }
+        let xs = if self.window_size > 0 {
+            window_unpartition(xs, self.window_size, pad_hw, hw)?
+        } else {
+            xs
+        };
         let xs = (xs + shortcut)?;
         &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
     }
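window_partition above pads the (h, w) feature map up to multiples of window_size, splits it into non-overlapping windows that are attended to independently, and window_unpartition later drops the padding again. The size bookkeeping, as a std-only sketch:

    // Shape bookkeeping behind window_partition/window_unpartition above:
    // pad (h, w) up to multiples of the window size, then count the resulting
    // windows per image. Std-only sketch; the diff does the reshapes on Tensors.
    fn window_layout(h: usize, w: usize, window_size: usize) -> ((usize, usize), usize) {
        let pad_h = (window_size - h % window_size) % window_size;
        let pad_w = (window_size - w % window_size) % window_size;
        let (h_p, w_p) = (h + pad_h, w + pad_w);
        let windows_per_image = (h_p / window_size) * (w_p / window_size);
        ((h_p, w_p), windows_per_image)
    }

    fn main() {
        // For instance, a 64x64 feature map with 14x14 windows is padded to 70x70,
        // giving 5 * 5 = 25 windows per image; window_unpartition then narrows the
        // 70x70 result back to the original 64x64.
        let ((h_p, w_p), n_windows) = window_layout(64, 64, 14);
        assert_eq!((h_p, w_p), (70, 70));
        assert_eq!(n_windows, 25);
        println!("padded to {h_p}x{w_p}, {n_windows} windows");
    }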
@@ -165,9 +276,9 @@ pub struct ImageEncoderViT {
     patch_embed: PatchEmbed,
     blocks: Vec<Block>,
     neck_conv1: candle_nn::Conv2d,
-    neck_ln1: LayerNorm,
+    neck_ln1: crate::LayerNorm2d,
     neck_conv2: candle_nn::Conv2d,
-    neck_ln2: LayerNorm,
+    neck_ln2: crate::LayerNorm2d,
     pos_embed: Option<Tensor>,
 }
@@ -222,13 +333,13 @@ impl ImageEncoderViT {
             Default::default(),
             vb.pp("neck.0"),
         )?;
-        let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?;
+        let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
         let cfg = candle_nn::Conv2dConfig {
             padding: 1,
             ..Default::default()
         };
         let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
-        let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?;
+        let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
         let pos_embed = if use_abs_pos {
             let p = vb.get(
                 (1, img_size / patch_size, img_size / patch_size, embed_dim),
@@ -1,5 +1,5 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{Linear, Module, VarBuilder};
 
 use crate::model_transformer::TwoWayTransformer;
 
@@ -60,7 +60,7 @@ pub struct MaskDecoder {
     mask_tokens: candle_nn::Embedding,
     iou_prediction_head: MlpMaskDecoder,
     output_upscaling_conv1: candle_nn::ConvTranspose2d,
-    output_upscaling_ln: LayerNorm,
+    output_upscaling_ln: crate::LayerNorm2d,
     output_upscaling_conv2: candle_nn::ConvTranspose2d,
     num_mask_tokens: usize,
     output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
@@ -99,7 +99,7 @@ impl MaskDecoder {
             vb.pp("output_upscaling.0"),
         )?;
         let output_upscaling_ln =
-            layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
+            crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
         let output_upscaling_conv2 = candle_nn::conv_transpose2d(
             transformer_dim / 4,
             transformer_dim / 8,
@@ -140,7 +140,7 @@ impl MaskDecoder {
         })
     }
 
-    fn forward(
+    pub fn forward(
         &self,
         image_embeddings: &Tensor,
         image_pe: &Tensor,
@@ -195,7 +195,7 @@ impl MaskDecoder {
         // Run the transformer
         let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
         let iou_token_out = hs.i((.., 0))?;
-        let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
+        let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
 
         // Upscale mask embeddings and predict masks using the masks tokens.
         let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
@@ -213,9 +213,8 @@ impl MaskDecoder {
         }
         let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
         let (b, c, h, w) = upscaled_embedding.dims4()?;
-        let masks = hyper_in
-            .matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
-            .reshape((b, 0, h, w))?;
+        let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
+        let masks = masks.reshape((b, masks.elem_count() / b / h / w, h, w))?;
 
         // Generate mask quality predictions.
         let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
|
||||
}
|
||||
|
||||
// Equivalent to torch.repeat_interleave
|
||||
fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
|
||||
todo!()
|
||||
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
|
||||
let img = img.unsqueeze(dim + 1)?;
|
||||
let mut dims = img.dims().to_vec();
|
||||
dims[dim + 1] = repeats;
|
||||
img.broadcast_as(dims)?.flatten(dim, dim + 1)
|
||||
}
|
||||
|
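The repeat_interleave stub is now implemented via unsqueeze, broadcast_as and flatten, which matches torch.repeat_interleave: each element along the chosen dimension is repeated in place rather than the whole tensor being tiled. A std-only sketch of the 1-D semantics:

    // What the new repeat_interleave above computes, on a plain Vec for clarity:
    // every element along the chosen axis is duplicated `repeats` times in place.
    // The tensor version gets there by unsqueezing a new axis, broadcasting it to
    // `repeats`, and flattening it back into the original axis.
    fn repeat_interleave_1d(xs: &[i64], repeats: usize) -> Vec<i64> {
        xs.iter()
            .flat_map(|&x| std::iter::repeat(x).take(repeats))
            .collect()
    }

    fn main() {
        assert_eq!(repeat_interleave_1d(&[1, 2, 3], 2), vec![1, 1, 2, 2, 3, 3]);
        // Contrast with a plain repeat/tile, which would give [1, 2, 3, 1, 2, 3].
    }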
@@ -1,5 +1,5 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{Linear, Module, VarBuilder};
 
 #[derive(Debug)]
 struct PostionEmbeddingRandom {
@@ -17,7 +17,7 @@ impl PostionEmbeddingRandom {
 
     fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
        let coords = coords.affine(2., -1.)?;
-        let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?;
+        let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
         let coords = (coords * (2. * std::f64::consts::PI))?;
         Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
     }
@@ -25,12 +25,14 @@ impl PostionEmbeddingRandom {
     fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
         let device = self.positional_encoding_gaussian_matrix.device();
-        let grid = Tensor::ones((h, w), DType::F32, device)?;
-        // TODO: cumsum
-        let x_embed = (&grid - 0.5)?;
-        // TODO: cumsum
-        let y_embed = (&grid - 0.5)?;
-        let x_embed = (x_embed / w as f64)?;
-        let y_embed = (y_embed / h as f64)?;
+        let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+        let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+        let x_embed = (x_embed / w as f64)?
+            .reshape((1, w))?
+            .broadcast_as((h, w))?;
+        let y_embed = (y_embed / h as f64)?
+            .reshape((h, 1))?
+            .broadcast_as((h, w))?;
         let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
         self.pe_encoding(&coords)?.permute((2, 0, 1))
     }
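The cumsum-based grid from the reference implementation is replaced by arange-built coordinates: pixel centers at i + 0.5, normalized by the grid size, broadcast across rows for x and across columns for y. A std-only sketch of the values x_embed and y_embed take:

    // The normalized coordinate grid built by the new forward above, written with
    // plain loops: x runs over columns, y over rows, both offset by half a pixel
    // and scaled into (0, 1). Std-only sketch; the diff builds the same grid with
    // arange + reshape + broadcast_as.
    fn coord_grid(h: usize, w: usize) -> Vec<(f32, f32)> {
        let mut coords = Vec::with_capacity(h * w);
        for y in 0..h {
            for x in 0..w {
                let x_embed = (x as f32 + 0.5) / w as f32;
                let y_embed = (y as f32 + 0.5) / h as f32;
                coords.push((x_embed, y_embed));
            }
        }
        coords
    }

    fn main() {
        // For a 2x2 grid the pixel centers sit at 0.25 and 0.75 along each axis.
        assert_eq!(
            coord_grid(2, 2),
            vec![(0.25, 0.25), (0.75, 0.25), (0.25, 0.75), (0.75, 0.75)]
        );
    }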
@@ -55,13 +57,14 @@ pub struct PromptEncoder {
     point_embeddings: Vec<candle_nn::Embedding>,
     not_a_point_embed: candle_nn::Embedding,
     mask_downscaling_conv1: candle_nn::Conv2d,
-    mask_downscaling_ln1: LayerNorm,
+    mask_downscaling_ln1: crate::LayerNorm2d,
     mask_downscaling_conv2: candle_nn::Conv2d,
-    mask_downscaling_ln2: LayerNorm,
+    mask_downscaling_ln2: crate::LayerNorm2d,
     mask_downscaling_conv3: candle_nn::Conv2d,
     no_mask_embed: candle_nn::Embedding,
     image_embedding_size: (usize, usize),
     input_image_size: (usize, usize),
+    embed_dim: usize,
 }
 
 impl PromptEncoder {
@@ -97,8 +100,9 @@ impl PromptEncoder {
             vb.pp("mask_downscaling.6"),
         )?;
         let mask_downscaling_ln1 =
-            layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
-        let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+            crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+        let mask_downscaling_ln2 =
+            crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
         let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
         let vb_e = vb.pp("point_embeddings");
         for i in 0..num_points_embeddings {
@@ -117,9 +121,16 @@ impl PromptEncoder {
             no_mask_embed,
             image_embedding_size,
             input_image_size,
+            embed_dim,
         })
     }
 
+    pub fn get_dense_pe(&self) -> Result<Tensor> {
+        self.pe_layer
+            .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
+            .unsqueeze(0)
+    }
+
     fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
         masks
             .apply(&self.mask_downscaling_conv1)?
@@ -133,7 +144,16 @@ impl PromptEncoder {
 
     fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
         let points = (points + 0.5)?;
-        let points = if pad { todo!() } else { points };
+        let dev = points.device();
+        let (points, labels) = if pad {
+            let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
+            let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
+            let points = Tensor::cat(&[&points, &padding_point], 1)?;
+            let labels = Tensor::cat(&[labels, &padding_label], 1)?;
+            (points, labels)
+        } else {
+            (points, labels.clone())
+        };
         let point_embedding = self
             .pe_layer
             .forward_with_coords(&points, self.input_image_size)?;
@@ -154,7 +174,7 @@ impl PromptEncoder {
         Tensor::cat(&[&ce1, &ce2], 1)
     }
 
-    fn forward(
+    pub fn forward(
         &self,
         points: Option<(&Tensor, &Tensor)>,
         boxes: Option<&Tensor>,
@@ -172,7 +192,9 @@ impl PromptEncoder {
             (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
             (Some(se_points), None) => se_points,
             (None, Some(se_boxes)) => se_boxes,
-            (None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?,
+            (None, None) => {
+                Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
+            }
         };
 
         let dense_embeddings = match masks {
@@ -180,7 +202,7 @@ impl PromptEncoder {
                 let emb = self.no_mask_embed.embeddings();
                 emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
                     1,
-                    0,
+                    emb.elem_count(),
                     self.image_embedding_size.0,
                     self.image_embedding_size.1,
                 ))?
@@ -5,6 +5,10 @@ use crate::model_image_encoder::ImageEncoderViT;
 use crate::model_mask_decoder::MaskDecoder;
 use crate::model_prompt_encoder::PromptEncoder;
 
+const PROMPT_EMBED_DIM: usize = 256;
+const IMAGE_SIZE: usize = 1024;
+const VIT_PATCH_SIZE: usize = 16;
+
 #[derive(Debug)]
 pub struct Sam {
     image_encoder: ImageEncoderViT,
@@ -22,10 +26,6 @@ impl Sam {
         encoder_global_attn_indexes: &[usize],
         vb: VarBuilder,
     ) -> Result<Self> {
-        const PROMPT_EMBED_DIM: usize = 256;
-        const IMAGE_SIZE: usize = 1024;
-        const VIT_PATCH_SIZE: usize = 16;
-
         let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
 
         let image_encoder = ImageEncoderViT::new(
@@ -69,4 +69,33 @@ impl Sam {
             pixel_mean,
         })
     }
+
+    pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> {
+        let img = self.preprocess(img)?.unsqueeze(0)?;
+        let img_embeddings = self.image_encoder.forward(&img)?;
+        let image_pe = self.prompt_encoder.get_dense_pe()?;
+        let (sparse_prompt_embeddings, dense_prompt_embeddings) =
+            self.prompt_encoder.forward(None, None, None)?;
+        let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
+            &img_embeddings,
+            &image_pe,
+            &sparse_prompt_embeddings,
+            &dense_prompt_embeddings,
+            multimask_output,
+        )?;
+        // TODO: post-processing.
+        Ok((low_res_mask, iou_predictions))
+    }
+
+    fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
+        let (c, h, w) = img.dims3()?;
+        let img = img
+            .broadcast_sub(&self.pixel_mean)?
+            .broadcast_div(&self.pixel_std)?;
+        if h > IMAGE_SIZE || w > IMAGE_SIZE {
+            candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
+        }
+        let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
+        img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
+    }
 }
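Sam::preprocess above normalizes the image with pixel_mean/pixel_std and then zero-pads the bottom and right edges so both spatial dimensions reach IMAGE_SIZE (1024); larger inputs are rejected. A std-only sketch of the size handling (the 224x224 input assumes the load_image224 loader used by main.rs above):

    // Padding performed by Sam::preprocess above: after mean/std normalization the
    // (c, h, w) image is zero-padded below and to the right until both spatial
    // dims reach IMAGE_SIZE. Std-only sketch of the size handling.
    const IMAGE_SIZE: usize = 1024;

    fn preprocess_padding(h: usize, w: usize) -> Result<(usize, usize), String> {
        if h > IMAGE_SIZE || w > IMAGE_SIZE {
            return Err(format!(
                "image is too large ({w}, {h}), maximum size {IMAGE_SIZE}"
            ));
        }
        // pad_with_zeros(dim, left, right) in the diff pads only on the "right"
        // side of each dimension, i.e. below and to the right of the image.
        Ok((IMAGE_SIZE - h, IMAGE_SIZE - w))
    }

    fn main() {
        // A 224x224 input is padded by 800 rows and 800 columns of zeros before
        // going through the image encoder.
        assert_eq!(preprocess_padding(224, 224), Ok((800, 800)));
    }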
@@ -36,7 +36,8 @@ impl Attention {
     fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
         let (b, n, c) = x.dims3()?;
         x.reshape((b, n, self.num_heads, c / self.num_heads))?
-            .transpose(1, 2)
+            .transpose(1, 2)?
+            .contiguous()
     }
 
     fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
@@ -102,8 +103,12 @@ impl TwoWayAttentionBlock {
             2,
             vb.pp("cross_attn_image_to_token"),
         )?;
-        // TODO: use relu in this mlp
-        let mlp = crate::MlpBlock::new(embedding_dim, mlp_dim, vb.pp("mlp"))?;
+        let mlp = crate::MlpBlock::new(
+            embedding_dim,
+            mlp_dim,
+            candle_nn::Activation::Relu,
+            vb.pp("mlp"),
+        )?;
         Ok(Self {
             self_attn,
             norm1,
@@ -126,7 +131,7 @@ impl TwoWayAttentionBlock {
     ) -> Result<(Tensor, Tensor)> {
         // Self attention block
         let queries = if self.skip_first_layer_pe {
-            self.self_attn.forward(queries, keys, queries)?
+            self.self_attn.forward(queries, queries, queries)?
         } else {
             let q = (queries + query_pe)?;
             let attn_out = self.self_attn.forward(&q, &q, queries)?;
@@ -28,7 +28,7 @@
 //! ```
 //!
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor};
+use candle::{DType, Result, Tensor, D};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct LayerNormConfig {
@@ -104,15 +104,15 @@ impl crate::Module for LayerNorm {
             DType::F16 | DType::BF16 => DType::F32,
             d => d,
         };
-        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+        let hidden_size = x.dim(D::Minus1)?;
         let x = x.to_dtype(internal_dtype)?;
         let x = if self.remove_mean {
-            let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+            let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
             x.broadcast_sub(&mean_x)?
         } else {
             x
         };
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
         match &self.bias {
@@ -41,8 +41,9 @@ impl Linear {
 
 impl super::Module for Linear {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let w = match x.dims() {
-            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+        let w = match *x.dims() {
+            [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
+            [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
             _ => self.weight.t()?,
         };
         let x = x.matmul(&w)?;
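The Linear change above picks how the weight is broadcast from the rank of the input, so 3-D and now 4-D activations go through a batched matmul without an explicit reshape. A std-only sketch of the dispatch (the shapes in the asserts are made-up examples):

    // Dispatch logic added to Linear::forward above: the weight matrix is broadcast
    // over zero, one, or two leading batch dimensions depending on the rank of the
    // input, so a (b1, b2, m, k) activation can hit a batched matmul directly.
    fn weight_batch_dims(input_dims: &[usize]) -> usize {
        match *input_dims {
            [_b1, _b2, _, _] => 2, // 4-D input: broadcast the weight over (b1, b2)
            [_bsize, _, _] => 1,   // 3-D input: broadcast the weight over the batch
            _ => 0,                // 2-D (or other): use the plain transposed weight
        }
    }

    fn main() {
        assert_eq!(weight_batch_dims(&[8, 12, 196, 64]), 2);
        assert_eq!(weight_batch_dims(&[8, 196, 64]), 1);
        assert_eq!(weight_batch_dims(&[196, 64]), 0);
    }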