Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
More segment-anything again. (#764)
* More segment-anything again.
* Transformer block forward.
* Two-ways transformer.
* Position embeddings.
* Sketch the prompt encoder.
* More prompt-encoder.
* More prompt-encoder.
* Add the main sam module.
* Embed the transformer.
* And hook the transformer forward step.
* Build the model.
* Handle the global attn indexes.
* Get the model to load.
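The diff below wires the new `TwoWayTransformer` into the `MaskDecoder`: the struct gains a `transformer` field, the constructor builds it (depth 2, 8 heads, MLP dim 2048) under the `transformer` weight prefix, and the forward pass calls `self.transformer.forward` in place of the old `run_transformer` stub. It also corrects the token count: `num_mask_tokens` is `num_multimask_outputs + 1`, one extra mask token on top of the multi-mask outputs. A minimal, self-contained sketch of the resulting token layout (not the commit's code; the batch size, prompt-token count, and `transformer_dim = 256` are illustrative assumptions):

    use candle::{DType, Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let num_multimask_outputs = 3;
        // Multi-mask outputs plus one single-mask token, as in the diff below.
        let num_mask_tokens = num_multimask_outputs + 1;
        let (batch, n_prompt_tokens, transformer_dim) = (1, 2, 256);
        // Stand-in for the transformer output `hs`: one IoU token, then the
        // mask tokens, then the prompt tokens, each of width `transformer_dim`.
        let hs = Tensor::zeros(
            (batch, 1 + num_mask_tokens + n_prompt_tokens, transformer_dim),
            DType::F32,
            &dev,
        )?;
        // The same slicing the decoder performs after the transformer runs.
        let iou_token_out = hs.i((.., 0))?;
        let mask_tokens_out = hs.i((.., 1..1 + num_mask_tokens))?;
        println!("iou {:?} masks {:?}", iou_token_out.dims(), mask_tokens_out.dims());
        Ok(())
    }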
@@ -1,6 +1,8 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
 
+use crate::model_transformer::TwoWayTransformer;
+
 #[derive(Debug)]
 struct MlpMaskDecoder {
     layers: Vec<Linear>,
@@ -53,7 +55,7 @@ impl Module for MlpMaskDecoder {
 }
 
 #[derive(Debug)]
-struct MaskDecoder {
+pub struct MaskDecoder {
     iou_token: candle_nn::Embedding,
     mask_tokens: candle_nn::Embedding,
     iou_prediction_head: MlpMaskDecoder,
@@ -62,17 +64,18 @@ struct MaskDecoder {
     output_upscaling_conv2: candle_nn::ConvTranspose2d,
     num_mask_tokens: usize,
     output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
+    transformer: TwoWayTransformer,
 }
 
 impl MaskDecoder {
-    fn new(
+    pub fn new(
         transformer_dim: usize,
         num_multimask_outputs: usize,
         iou_head_depth: usize,
         iou_head_hidden_dim: usize,
         vb: VarBuilder,
     ) -> Result<Self> {
-        let num_mask_tokens = num_multimask_outputs - 1;
+        let num_mask_tokens = num_multimask_outputs + 1;
         let iou_prediction_head = MlpMaskDecoder::new(
             transformer_dim,
             iou_head_hidden_dim,
@@ -117,6 +120,13 @@ impl MaskDecoder {
             )?;
             output_hypernetworks_mlps.push(mlp)
         }
+        let transformer = TwoWayTransformer::new(
+            /* depth */ 2,
+            /* embedding_dim */ transformer_dim,
+            /* num_heads */ 8,
+            /* mlp_dim */ 2048,
+            vb.pp("transformer"),
+        )?;
         Ok(Self {
             iou_token,
             mask_tokens,
@@ -126,6 +136,7 @@ impl MaskDecoder {
             output_upscaling_conv2,
             num_mask_tokens,
             output_hypernetworks_mlps,
+            transformer,
         })
     }
 
@@ -182,7 +193,7 @@ impl MaskDecoder {
         let (b, c, h, w) = src.dims4()?;
 
         // Run the transformer
-        let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
+        let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
         let iou_token_out = hs.i((.., 0))?;
         let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
 
@@ -216,7 +227,3 @@ impl MaskDecoder {
 fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
     todo!()
 }
-
-fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
-    todo!()
-}
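With `MaskDecoder` and its constructor now `pub`, the decoder can be instantiated from outside the module. A hypothetical call site, assuming `MaskDecoder` is in scope from the module changed above and that the weights live in a SAM-style safetensors file (the file name, the `mask_decoder` prefix, and the argument values mirror SAM's defaults but are assumptions here, not taken from this commit):

    use candle::{DType, Device, Result};
    use candle_nn::VarBuilder;

    fn build_decoder() -> Result<MaskDecoder> {
        let dev = Device::Cpu;
        // Memory-maps the checkpoint; unsafe because the file must not be
        // modified while it is mapped.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&["sam.safetensors"], DType::F32, &dev)?
        };
        MaskDecoder::new(
            /* transformer_dim */ 256,
            /* num_multimask_outputs */ 3,
            /* iou_head_depth */ 3,
            /* iou_head_hidden_dim */ 256,
            vb.pp("mask_decoder"),
        )
    }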