mirror of https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
More segment-anything again. (#764)
* More segment-anything again.
* Transformer block forward.
* Two-ways transformer.
* Position embeddings.
* Sketch the prompt encoder.
* More prompt-encoder.
* More prompt-encoder.
* Add the main sam module.
* Embed the transformer.
* And hook the transformer forward step.
* Build the model.
* Handle the global attn indexes.
* Get the model to load.
candle-examples/examples/segment-anything/main.rs

@@ -8,9 +8,11 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-mod model_image_encoder;
-mod model_mask_decoder;
-mod model_transformer;
+pub mod model_image_encoder;
+pub mod model_mask_decoder;
+pub mod model_prompt_encoder;
+pub mod model_sam;
+pub mod model_transformer;
 
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
@@ -82,7 +84,7 @@ impl Module for MlpBlock {
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
-    model: Option<String>,
+    model: String,
 
     #[arg(long)]
     image: String,
@@ -95,10 +97,15 @@ struct Args {
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
 
-    let _device = candle_examples::device(args.cpu)?;
+    let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image224(args.image)?;
+    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
     println!("loaded image {image:?}");
 
+    let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+    let weights = weights.deserialize()?;
+    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+    let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+
     Ok(())
 }
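The hard-coded `Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)` call is the ViT-B configuration, as the trailing comment notes: embed dim 768, depth 12, 12 heads, global attention in blocks 2, 5, 8, and 11. For orientation, a hedged sketch of the other standard SAM presets; these helper functions are hypothetical, and the constants come from the reference SAM implementation rather than from this commit:

// Hypothetical helpers (not in this commit) showing the remaining SAM
// image-encoder presets in the same calling convention as above.
fn sam_vit_l(vb: VarBuilder) -> Result<model_sam::Sam> {
    // ViT-L: embed_dim 1024, depth 24, 16 heads, global attention in blocks 5/11/17/23.
    model_sam::Sam::new(1024, 24, 16, &[5, 11, 17, 23], vb)
}

fn sam_vit_h(vb: VarBuilder) -> Result<model_sam::Sam> {
    // ViT-H: embed_dim 1280, depth 32, 16 heads, global attention in blocks 7/15/23/31.
    model_sam::Sam::new(1280, 32, 16, &[7, 15, 23, 31], vb)
}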
candle-examples/examples/segment-anything/model_image_encoder.rs

@@ -47,7 +47,7 @@ impl Attention {
         num_heads: usize,
         qkv_bias: bool,
         use_rel_pos: bool,
-        window_size: usize,
+        input_size: (usize, usize),
         vb: VarBuilder,
     ) -> Result<Self> {
         let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
@@ -55,8 +55,8 @@ impl Attention {
         let head_dim = dim / num_heads;
         let scale = 1. / (head_dim as f64).sqrt();
         let rel_pos_hw = if use_rel_pos {
-            let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
-            let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
+            let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
+            let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
             Some((h, w))
         } else {
             None
@@ -114,16 +114,22 @@ impl Block {
         qkv_bias: bool,
         use_rel_pos: bool,
         window_size: usize,
+        input_size: (usize, usize),
         vb: VarBuilder,
     ) -> Result<Self> {
         let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
         let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+        let input_size_attn = if window_size == 0 {
+            input_size
+        } else {
+            (window_size, window_size)
+        };
         let attn = Attention::new(
             dim,
             num_heads,
             qkv_bias,
             use_rel_pos,
-            window_size,
+            input_size_attn,
             vb.pp("attn"),
         )?;
         let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
@@ -154,7 +160,7 @@ impl Module for Block {
 }
 
 #[derive(Debug)]
-struct ImageEncoderViT {
+pub struct ImageEncoderViT {
     img_size: usize,
     patch_embed: PatchEmbed,
     blocks: Vec<Block>,
@@ -167,7 +173,7 @@ struct ImageEncoderViT {
 
 impl ImageEncoderViT {
     #[allow(clippy::too_many_arguments)]
-    fn new(
+    pub fn new(
         img_size: usize,
         patch_size: usize,
         in_chans: usize,
@@ -179,6 +185,7 @@ impl ImageEncoderViT {
         use_rel_pos: bool,
         use_abs_pos: bool,
         window_size: usize,
+        global_attn_indexes: &[usize],
         vb: VarBuilder,
     ) -> Result<Self> {
         let patch_embed = PatchEmbed::new(
@@ -192,12 +199,18 @@ impl ImageEncoderViT {
         let mut blocks = Vec::with_capacity(depth);
         let vb_b = vb.pp("blocks");
         for i in 0..depth {
+            let window_size = if global_attn_indexes.contains(&i) {
+                0
+            } else {
+                window_size
+            };
             let block = Block::new(
                 embed_dim,
                 num_heads,
                 qkv_bias,
                 use_rel_pos,
                 window_size,
+                (img_size / patch_size, img_size / patch_size),
                 vb_b.pp(i),
             )?;
             blocks.push(block)
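The input_size plumbing exists because the relative-position tables must match the attention grid: a block whose index is in global_attn_indexes gets window_size 0 and attends over the full patch grid, while every other block attends inside a window_size x window_size window. A minimal sketch of the sizing rule (the helper is illustrative, not code from the diff):

// Illustrative only: the shapes of the rel_pos_h / rel_pos_w tables loaded
// above. An (h, w) attention grid has 2*h - 1 possible vertical offsets and
// 2*w - 1 horizontal ones.
fn rel_pos_dims(input_size: (usize, usize), head_dim: usize) -> [(usize, usize); 2] {
    [
        (2 * input_size.0 - 1, head_dim), // rel_pos_h
        (2 * input_size.1 - 1, head_dim), // rel_pos_w
    ]
}

// For sam_vit_b (head_dim = 768 / 12 = 64): windowed blocks use a 14x14 grid,
// giving (27, 64) tables; global blocks use the full 64x64 patch grid
// (1024 / 16 = 64), giving (127, 64).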
candle-examples/examples/segment-anything/model_mask_decoder.rs

@@ -1,6 +1,8 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
 
+use crate::model_transformer::TwoWayTransformer;
+
 #[derive(Debug)]
 struct MlpMaskDecoder {
     layers: Vec<Linear>,
@@ -53,7 +55,7 @@ impl Module for MlpMaskDecoder {
 }
 
 #[derive(Debug)]
-struct MaskDecoder {
+pub struct MaskDecoder {
     iou_token: candle_nn::Embedding,
     mask_tokens: candle_nn::Embedding,
     iou_prediction_head: MlpMaskDecoder,
@@ -62,17 +64,18 @@ struct MaskDecoder {
     output_upscaling_conv2: candle_nn::ConvTranspose2d,
     num_mask_tokens: usize,
     output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
+    transformer: TwoWayTransformer,
 }
 
 impl MaskDecoder {
-    fn new(
+    pub fn new(
         transformer_dim: usize,
         num_multimask_outputs: usize,
         iou_head_depth: usize,
         iou_head_hidden_dim: usize,
         vb: VarBuilder,
     ) -> Result<Self> {
-        let num_mask_tokens = num_multimask_outputs - 1;
+        let num_mask_tokens = num_multimask_outputs + 1;
         let iou_prediction_head = MlpMaskDecoder::new(
             transformer_dim,
             iou_head_hidden_dim,
@@ -117,6 +120,13 @@ impl MaskDecoder {
             )?;
             output_hypernetworks_mlps.push(mlp)
         }
+        let transformer = TwoWayTransformer::new(
+            /* depth */ 2,
+            /* embedding_dim */ transformer_dim,
+            /* num_heads */ 8,
+            /* mlp_dim */ 2048,
+            vb.pp("transformer"),
+        )?;
         Ok(Self {
             iou_token,
             mask_tokens,
@@ -126,6 +136,7 @@ impl MaskDecoder {
             output_upscaling_conv2,
             num_mask_tokens,
             output_hypernetworks_mlps,
+            transformer,
         })
     }
 
@@ -182,7 +193,7 @@ impl MaskDecoder {
         let (b, c, h, w) = src.dims4()?;
 
         // Run the transformer
-        let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
+        let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
         let iou_token_out = hs.i((.., 0))?;
         let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
 
@@ -216,7 +227,3 @@ impl MaskDecoder {
 fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
     todo!()
 }
-
-fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
-    todo!()
-}
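The sign flip in num_mask_tokens is the substantive fix in this file: SAM keeps one mask token per multimask output plus one token for the single-mask path, so the count is num_multimask_outputs + 1, not - 1. A quick sanity check of the resulting token layout, assuming the reference SAM convention that hs starts with the IoU token followed by the mask tokens (consistent with the hs.i((.., 0)) indexing above):

// With num_multimask_outputs = 3 (the value Sam::new passes in below),
// the decoder owns 3 + 1 = 4 mask tokens; in the transformer output `hs`,
// index 0 is the IoU token and the mask tokens follow it.
const NUM_MULTIMASK_OUTPUTS: usize = 3;
const NUM_MASK_TOKENS: usize = NUM_MULTIMASK_OUTPUTS + 1; // = 4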
candle-examples/examples/segment-anything/model_prompt_encoder.rs (new file)

@@ -0,0 +1,192 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+#[derive(Debug)]
+struct PostionEmbeddingRandom {
+    positional_encoding_gaussian_matrix: Tensor,
+}
+
+impl PostionEmbeddingRandom {
+    fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
+        let positional_encoding_gaussian_matrix =
+            vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
+        Ok(Self {
+            positional_encoding_gaussian_matrix,
+        })
+    }
+
+    fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
+        let coords = coords.affine(2., -1.)?;
+        let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?;
+        let coords = (coords * (2. * std::f64::consts::PI))?;
+        Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
+    }
+
+    fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
+        let device = self.positional_encoding_gaussian_matrix.device();
+        let grid = Tensor::ones((h, w), DType::F32, device)?;
+        // TODO: cumsum
+        let x_embed = (&grid - 0.5)?;
+        // TODO: cumsum
+        let y_embed = (&grid - 0.5)?;
+        let x_embed = (x_embed / w as f64)?;
+        let y_embed = (y_embed / h as f64)?;
+        let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
+        self.pe_encoding(&coords)?.permute((2, 0, 1))
+    }
+
+    fn forward_with_coords(
+        &self,
+        coords_input: &Tensor,
+        image_size: (usize, usize),
+    ) -> Result<Tensor> {
+        let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
+        let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
+        let c = coords_input.dim(D::Minus1)?;
+        let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
+        let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
+        self.pe_encoding(&coords)
+    }
+}
+
+#[derive(Debug)]
+pub struct PromptEncoder {
+    pe_layer: PostionEmbeddingRandom,
+    point_embeddings: Vec<candle_nn::Embedding>,
+    not_a_point_embed: candle_nn::Embedding,
+    mask_downscaling_conv1: candle_nn::Conv2d,
+    mask_downscaling_ln1: LayerNorm,
+    mask_downscaling_conv2: candle_nn::Conv2d,
+    mask_downscaling_ln2: LayerNorm,
+    mask_downscaling_conv3: candle_nn::Conv2d,
+    no_mask_embed: candle_nn::Embedding,
+    image_embedding_size: (usize, usize),
+    input_image_size: (usize, usize),
+}
+
+impl PromptEncoder {
+    pub fn new(
+        embed_dim: usize,
+        image_embedding_size: (usize, usize),
+        input_image_size: (usize, usize),
+        mask_in_chans: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let num_points_embeddings = 4;
+        let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
+        let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
+        let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
+        let cfg = candle_nn::Conv2dConfig {
+            stride: 2,
+            ..Default::default()
+        };
+        let mask_downscaling_conv1 =
+            candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
+        let mask_downscaling_conv2 = candle_nn::conv2d(
+            mask_in_chans / 4,
+            mask_in_chans,
+            2,
+            cfg,
+            vb.pp("mask_downscaling.3"),
+        )?;
+        let mask_downscaling_conv3 = candle_nn::conv2d(
+            mask_in_chans,
+            embed_dim,
+            1,
+            Default::default(),
+            vb.pp("mask_downscaling.6"),
+        )?;
+        let mask_downscaling_ln1 =
+            layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+        let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+        let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
+        let vb_e = vb.pp("point_embeddings");
+        for i in 0..num_points_embeddings {
+            let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
+            point_embeddings.push(emb)
+        }
+        Ok(Self {
+            pe_layer,
+            point_embeddings,
+            not_a_point_embed,
+            mask_downscaling_conv1,
+            mask_downscaling_ln1,
+            mask_downscaling_conv2,
+            mask_downscaling_ln2,
+            mask_downscaling_conv3,
+            no_mask_embed,
+            image_embedding_size,
+            input_image_size,
+        })
+    }
+
+    fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
+        masks
+            .apply(&self.mask_downscaling_conv1)?
+            .apply(&self.mask_downscaling_ln1)?
+            .gelu()?
+            .apply(&self.mask_downscaling_conv2)?
+            .apply(&self.mask_downscaling_ln2)?
+            .gelu()?
+            .apply(&self.mask_downscaling_conv3)
+    }
+
+    fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
+        let points = (points + 0.5)?;
+        let points = if pad { todo!() } else { points };
+        let point_embedding = self
+            .pe_layer
+            .forward_with_coords(&points, self.input_image_size)?;
+        // TODO: tweak based on labels.
+        Ok(point_embedding)
+    }
+
+    fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
+        let boxes = (boxes + 0.5)?;
+        let coords = boxes.reshape((boxes.elem_count() / 4, 2, 2))?;
+        let corner_embedding = self
+            .pe_layer
+            .forward_with_coords(&coords, self.input_image_size)?;
+        let ce1 = corner_embedding.i((.., 0))?;
+        let ce2 = corner_embedding.i((.., 1))?;
+        let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
+        let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
+        Tensor::cat(&[&ce1, &ce2], 1)
+    }
+
+    fn forward(
+        &self,
+        points: Option<(&Tensor, &Tensor)>,
+        boxes: Option<&Tensor>,
+        masks: Option<&Tensor>,
+    ) -> Result<(Tensor, Tensor)> {
+        let se_points = match points {
+            Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
+            None => None,
+        };
+        let se_boxes = match boxes {
+            Some(boxes) => Some(self.embed_boxes(boxes)?),
+            None => None,
+        };
+        let sparse_embeddings = match (se_points, se_boxes) {
+            (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
+            (Some(se_points), None) => se_points,
+            (None, Some(se_boxes)) => se_boxes,
+            (None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?,
+        };
+
+        let dense_embeddings = match masks {
+            None => {
+                let emb = self.no_mask_embed.embeddings();
+                emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
+                    1,
+                    0,
+                    self.image_embedding_size.0,
+                    self.image_embedding_size.1,
+                ))?
+            }
+            Some(masks) => self.embed_masks(masks)?,
+        };
+        Ok((sparse_embeddings, dense_embeddings))
+    }
+}
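PostionEmbeddingRandom is a Fourier-feature positional encoding: coordinates normalized to [0, 1] are shifted to [-1, 1], projected through a fixed 2 x num_pos_feats Gaussian matrix, scaled by 2*pi, and expanded into sin/cos features. A scalar sketch of pe_encoding for a single point (not from the diff; `gaussian` stands in for the checkpoint matrix, stored here column-wise for convenience):

// Sketch of what pe_encoding computes for one (x, y) already in [0, 1].
fn pe_encoding_scalar(x: f32, y: f32, gaussian: &[[f32; 2]]) -> Vec<f32> {
    use std::f32::consts::PI;
    // Map [0, 1] coordinates to [-1, 1], as coords.affine(2., -1.) does above.
    let (cx, cy) = (2.0 * x - 1.0, 2.0 * y - 1.0);
    // Project onto the Gaussian directions and scale by 2*pi.
    let proj: Vec<f32> = gaussian
        .iter()
        .map(|g| 2.0 * PI * (cx * g[0] + cy * g[1]))
        .collect();
    // Concatenate sin and cos features, matching Tensor::cat(&[sin, cos], -1).
    proj.iter().map(|v| v.sin()).chain(proj.iter().map(|v| v.cos())).collect()
}

The forward_with_coords variant performs the [0, 1] normalization itself, dividing x coordinates by the image width (image_size.1) and y coordinates by the height (image_size.0) before calling pe_encoding.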
candle-examples/examples/segment-anything/model_sam.rs (new file, 72 lines)

@@ -0,0 +1,72 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+use crate::model_image_encoder::ImageEncoderViT;
+use crate::model_mask_decoder::MaskDecoder;
+use crate::model_prompt_encoder::PromptEncoder;
+
+#[derive(Debug)]
+pub struct Sam {
+    image_encoder: ImageEncoderViT,
+    prompt_encoder: PromptEncoder,
+    mask_decoder: MaskDecoder,
+    pixel_mean: Tensor,
+    pixel_std: Tensor,
+}
+
+impl Sam {
+    pub fn new(
+        encoder_embed_dim: usize,
+        encoder_depth: usize,
+        encoder_num_heads: usize,
+        encoder_global_attn_indexes: &[usize],
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        const PROMPT_EMBED_DIM: usize = 256;
+        const IMAGE_SIZE: usize = 1024;
+        const VIT_PATCH_SIZE: usize = 16;
+
+        let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+        let image_encoder = ImageEncoderViT::new(
+            IMAGE_SIZE,
+            VIT_PATCH_SIZE,
+            3,
+            encoder_embed_dim,
+            encoder_depth,
+            encoder_num_heads,
+            PROMPT_EMBED_DIM,
+            /* qkv_bias */ true,
+            /* use_rel_pos */ true,
+            /* use_abs_pos */ true,
+            /* window_size */ 14,
+            /* global_attn_indexes */ encoder_global_attn_indexes,
+            vb.pp("image_encoder"),
+        )?;
+        let prompt_encoder = PromptEncoder::new(
+            PROMPT_EMBED_DIM,
+            (image_embedding_size, image_embedding_size),
+            (IMAGE_SIZE, IMAGE_SIZE),
+            16,
+            vb.pp("prompt_encoder"),
+        )?;
+        let mask_decoder = MaskDecoder::new(
+            PROMPT_EMBED_DIM,
+            /* num_multimask_outputs */ 3,
+            /* iou_head_depth */ 3,
+            /* iou_head_hidden_dim */ 256,
+            vb.pp("mask_decoder"),
+        )?;
+        let pixel_mean =
+            Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+        let pixel_std =
+            Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+        Ok(Self {
+            image_encoder,
+            prompt_encoder,
+            mask_decoder,
+            pixel_std,
+            pixel_mean,
+        })
+    }
+}
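With IMAGE_SIZE 1024 and VIT_PATCH_SIZE 16, image_embedding_size works out to 64, so the prompt encoder is told to expect a 64x64 embedding grid. The pixel_mean and pixel_std tensors are stored but not yet used in this commit; a hedged sketch of the per-channel normalization they are presumably staged for (the `normalize` function is hypothetical, the broadcast ops are candle's):

use candle::{Result, Tensor};

// Hypothetical preprocessing step: normalize a (3, h, w) image in 0..=255
// with the (3, 1, 1) pixel_mean / pixel_std tensors built in Sam::new.
fn normalize(img: &Tensor, pixel_mean: &Tensor, pixel_std: &Tensor) -> Result<Tensor> {
    img.broadcast_sub(pixel_mean)?.broadcast_div(pixel_std)
}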
candle-examples/examples/segment-anything/model_transformer.rs

@@ -75,3 +75,146 @@ struct TwoWayAttentionBlock {
     cross_attn_image_to_token: Attention,
     skip_first_layer_pe: bool,
 }
+
+impl TwoWayAttentionBlock {
+    fn new(
+        embedding_dim: usize,
+        num_heads: usize,
+        mlp_dim: usize,
+        skip_first_layer_pe: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
+        let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
+        let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
+        let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
+        let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
+        let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
+        let cross_attn_token_to_image = Attention::new(
+            embedding_dim,
+            num_heads,
+            2,
+            vb.pp("cross_attn_token_to_image"),
+        )?;
+        let cross_attn_image_to_token = Attention::new(
+            embedding_dim,
+            num_heads,
+            2,
+            vb.pp("cross_attn_image_to_token"),
+        )?;
+        // TODO: use relu in this mlp
+        let mlp = crate::MlpBlock::new(embedding_dim, mlp_dim, vb.pp("mlp"))?;
+        Ok(Self {
+            self_attn,
+            norm1,
+            cross_attn_image_to_token,
+            norm2,
+            mlp,
+            norm3,
+            norm4,
+            cross_attn_token_to_image,
+            skip_first_layer_pe,
+        })
+    }
+
+    fn forward(
+        &self,
+        queries: &Tensor,
+        keys: &Tensor,
+        query_pe: &Tensor,
+        key_pe: &Tensor,
+    ) -> Result<(Tensor, Tensor)> {
+        // Self attention block
+        let queries = if self.skip_first_layer_pe {
+            self.self_attn.forward(queries, keys, queries)?
+        } else {
+            let q = (queries + query_pe)?;
+            let attn_out = self.self_attn.forward(&q, &q, queries)?;
+            (queries + attn_out)?
+        };
+        let queries = self.norm1.forward(&queries)?;
+
+        // Cross attention block, tokens attending to image embedding
+        let q = (&queries + query_pe)?;
+        let k = (keys + key_pe)?;
+        let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
+        let queries = (&queries + attn_out)?;
+        let queries = self.norm2.forward(&queries)?;
+
+        // MLP block
+        let mlp_out = self.mlp.forward(&queries);
+        let queries = (queries + mlp_out)?;
+        let queries = self.norm3.forward(&queries)?;
+
+        // Cross attention block, image embedding attending to tokens
+        let q = (&queries + query_pe)?;
+        let k = (keys + key_pe)?;
+        let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
+        let keys = (keys + attn_out)?;
+        let keys = self.norm4.forward(&keys)?;
+
+        Ok((queries, keys))
+    }
+}
+
+#[derive(Debug)]
+pub struct TwoWayTransformer {
+    layers: Vec<TwoWayAttentionBlock>,
+    final_attn_token_to_image: Attention,
+    norm_final_attn: LayerNorm,
+}
+
+impl TwoWayTransformer {
+    pub fn new(
+        depth: usize,
+        embedding_dim: usize,
+        num_heads: usize,
+        mlp_dim: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let vb_l = vb.pp("layers");
+        let mut layers = Vec::with_capacity(depth);
+        for i in 0..depth {
+            let layer =
+                TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
+            layers.push(layer)
+        }
+        let final_attn_token_to_image = Attention::new(
+            embedding_dim,
+            num_heads,
+            2,
+            vb.pp("final_attn_token_to_image"),
+        )?;
+        let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
+        Ok(Self {
+            layers,
+            final_attn_token_to_image,
+            norm_final_attn,
+        })
+    }
+
+    pub fn forward(
+        &self,
+        image_embedding: &Tensor,
+        image_pe: &Tensor,
+        point_embedding: &Tensor,
+    ) -> Result<(Tensor, Tensor)> {
+        let (bs, c, h, w) = image_embedding.dims4()?;
+        let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
+        let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
+
+        let mut queries = point_embedding.clone();
+        let mut keys = image_embedding;
+
+        for layer in self.layers.iter() {
+            (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
+        }
+
+        let q = (&queries + point_embedding)?;
+        let k = (&keys + image_pe)?;
+        let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
+        let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
+
+        Ok((queries, keys))
+    }
+}
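TwoWayTransformer::forward flattens the (B, C, H, W) image embedding into a (B, H*W, C) sequence of keys and treats the prompt tokens as queries; each TwoWayAttentionBlock preserves those shapes, so the final (queries, keys) pair keeps them too. A minimal shape check of the flatten_from(2)?.permute((0, 2, 1)) step, using SAM's sizes (B = 1, C = 256, 64x64 grid) on dummy CPU data; a sketch, not code from the commit:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // A (B, C, H, W) image embedding with SAM's sizes.
    let img = Tensor::zeros((1, 256, 64, 64), DType::F32, &Device::Cpu)?;
    // Flatten the spatial dims, then move channels last: (B, H*W, C).
    let keys = img.flatten_from(2)?.permute((0, 2, 1))?;
    assert_eq!(keys.dims(), &[1, 4096, 256]);
    Ok(())
}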