//! SAM: Segment Anything Model
//! https://github.com/facebookresearch/segment-anything
#![allow(unused)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

pub mod model_image_encoder;
pub mod model_mask_decoder;
pub mod model_prompt_encoder;
pub mod model_sam;
pub mod model_transformer;

use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use clap::Parser;

/// Build a linear layer, with or without a trainable bias.
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
    if bias {
        candle_nn::linear(in_dim, out_dim, vb)
    } else {
        candle_nn::linear_no_bias(in_dim, out_dim, vb)
    }
}
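
// A minimal usage sketch for the `linear` helper above; the "proj" prefix and
// the 256-dim sizes are hypothetical, real prefixes and sizes come from the
// SAM checkpoint layout consumed by the model_* modules:
//
//     let proj = linear(vb.pp("proj"), 256, 256, /* bias */ true)?;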

/// A two-layer MLP with a GELU activation in between, as used inside the
/// transformer blocks.
#[derive(Debug)]
pub struct MlpBlock {
    lin1: Linear,
    lin2: Linear,
}

impl MlpBlock {
    pub fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
        let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
        let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
        Ok(Self { lin1, lin2 })
    }
}

impl Module for MlpBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
    }
}
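
// A minimal smoke test sketching the lin1 -> gelu -> lin2 flow of `MlpBlock`;
// the zero-initialized VarBuilder and the tiny dimensions are assumptions for
// illustration only, real weights come from the safetensors file loaded in
// `main` below.
#[cfg(test)]
mod mlp_block_tests {
    use super::*;
    use candle::Device;

    #[test]
    fn mlp_block_preserves_embedding_dim() -> Result<()> {
        let device = Device::Cpu;
        let vb = VarBuilder::zeros(DType::F32, &device);
        let block = MlpBlock::new(4, 8, vb)?;
        // Input shaped (batch, tokens, embedding_dim).
        let xs = Tensor::zeros((1, 3, 4), DType::F32, &device)?;
        let ys = block.forward(&xs)?;
        // The block expands to mlp_dim internally but maps back to the
        // embedding dimension it was given.
        assert_eq!(ys.dims(), &[1, 3, 4]);
        Ok(())
    }
}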

/*
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
    let npatch = xs.dim(1)? - 1;
    let n = self.pos_embed.dim(1)? - 1;
    let sqrt_n = (n as f64).sqrt();
    if npatch == n && w == h {
        return Ok(xs.clone());
    }
    let class_pos_embed = self.pos_embed.i((.., ..1))?;
    let patch_pos_embed = self.pos_embed.i((.., 1..))?;
    let dim = xs.dim(D::Minus1)?;
    let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
    let patch_pos_embed = patch_pos_embed
        .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
        .transpose(2, 3)?
        .transpose(1, 2)?;
    // This uses bicubic interpolation in the original implementation.
    let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
    let el_count = patch_pos_embed.shape().elem_count();
    let patch_pos_embed = patch_pos_embed
        .transpose(1, 2)?
        .transpose(2, 3)?
        .reshape((1, el_count / dim, dim))?;
    Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
}

fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
    let (_b, _nc, w, h) = xs.dims4()?;
    let xs = self.patch_embed.forward(xs)?;
    let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
    &xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
*/

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: String,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b

    Ok(())
}
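
// A hypothetical invocation of this example, assuming the usual candle
// example layout and a locally downloaded ViT-B checkpoint (both the example
// name and the file names below are illustrative):
//
//     cargo run --example segment-anything --release -- \
//         --model sam_vit_b.safetensors --image input.jpg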