More segment-anything again. (#764)

* More segment-anything again.

* Transformer block forward.

* Two-ways transformer.

* Position embeddings (sketched right after this list).

* Sketch the prompt encoder.

* More prompt-encoder.

* More prompt-encoder.

* Add the main sam module.

* Embed the transformer.

* And hook the transformer forward step.

* Build the model.

* Handle the global attn indexes.

* Get the model to load.
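
The position embeddings mentioned above belong to the prompt encoder. As a hedged sketch only, assuming the random-Fourier-feature recipe of the original SAM prompt encoder (the actual code lands in model_prompt_encoder and is not shown in the diff below): normalized point coordinates are projected by a fixed Gaussian matrix and then mapped through sin/cos.

use std::f32::consts::PI;

// Illustration only, not code from this commit: random-Fourier position
// encoding in the style of SAM's prompt encoder. `gaussian` stands in for the
// fixed random projection matrix, one (gx, gy) row per output frequency.
fn fourier_position_embedding(x: f32, y: f32, gaussian: &[(f32, f32)]) -> Vec<f32> {
    // Map normalized coordinates from [0, 1] to [-1, 1].
    let (x, y) = (2.0 * x - 1.0, 2.0 * y - 1.0);
    // Project the coordinate onto each random frequency.
    let projected: Vec<f32> = gaussian
        .iter()
        .map(|&(gx, gy)| 2.0 * PI * (x * gx + y * gy))
        .collect();
    // The embedding is the concatenation of sin and cos features.
    projected
        .iter()
        .map(|v| v.sin())
        .chain(projected.iter().map(|v| v.cos()))
        .collect()
}

fn main() {
    // Two example frequencies; the real matrix has many more rows.
    let gaussian = [(0.3_f32, -1.2), (0.8, 0.5)];
    println!("{:?}", fourier_position_embedding(0.25, 0.75, &gaussian));
}

In the original SAM implementation the projection matrix has shape (2, embed_dim / 2), so the concatenated sin/cos features match the embedding width.
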
Laurent Mazare authored on 2023-09-07 13:06:55 +02:00 (committed by GitHub)
parent 8c991df394
commit 7b50f3e106
6 changed files with 454 additions and 20 deletions

@@ -8,9 +8,11 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-mod model_image_encoder;
-mod model_mask_decoder;
-mod model_transformer;
+pub mod model_image_encoder;
+pub mod model_mask_decoder;
+pub mod model_prompt_encoder;
+pub mod model_sam;
+pub mod model_transformer;
 
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
@@ -82,7 +84,7 @@ impl Module for MlpBlock {
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
-    model: Option<String>,
+    model: String,
 
     #[arg(long)]
     image: String,
@@ -95,10 +97,15 @@ struct Args {
 
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
-    let _device = candle_examples::device(args.cpu)?;
+    let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image224(args.image)?;
+    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device);
     println!("loaded image {image:?}");
 
+    let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+    let weights = weights.deserialize()?;
+    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+    let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+
     Ok(())
 }
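
The final line added above instantiates the ViT-B variant: embedding width 768, 12 transformer blocks, 12 attention heads, and global attention at blocks 2, 5, 8 and 11. The self-contained sketch below only illustrates how such a configuration is conventionally read in SAM-style image encoders (windowed attention everywhere except at the listed indexes); it is an assumption about the design, not code from this commit.

// Illustration only: the ViT-B hyperparameters passed to `Sam::new` above and
// the usual meaning of the global-attention indexes.
struct VitConfig {
    embed_dim: usize,                      // width of the patch embeddings
    depth: usize,                          // number of transformer blocks
    num_heads: usize,                      // attention heads per block
    global_attn_indexes: &'static [usize], // blocks that attend over the full feature map
}

const SAM_VIT_B: VitConfig = VitConfig {
    embed_dim: 768,
    depth: 12,
    num_heads: 12,
    global_attn_indexes: &[2, 5, 8, 11],
};

fn main() {
    // Blocks not listed in `global_attn_indexes` are assumed to use windowed
    // (local) attention.
    for i in 0..SAM_VIT_B.depth {
        let kind = if SAM_VIT_B.global_attn_indexes.contains(&i) {
            "global"
        } else {
            "windowed"
        };
        println!(
            "block {i:2}: {kind} attention ({} heads, dim {})",
            SAM_VIT_B.num_heads, SAM_VIT_B.embed_dim
        );
    }
}

Running the sketch simply prints which attention flavour each of the 12 blocks would use.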