Move more models to candle-transformers (#796)

* Move dinov2.

* Move efficientnet.

* Move the quantized llama model.

* Move segment-anything.
This commit is contained in:
Laurent Mazare
2023-09-10 10:20:18 +01:00
committed by GitHub
parent d3f05eae8c
commit 35f72514f5
21 changed files with 773 additions and 759 deletions

View File

@ -7,108 +7,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
pub mod model_image_encoder;
pub mod model_mask_decoder;
pub mod model_prompt_encoder;
pub mod model_sam;
pub mod model_tiny_vit;
pub mod model_transformer;
use candle::{DType, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle::DType;
use candle_nn::VarBuilder;
use candle_transformers::models::segment_anything::sam;
use clap::Parser;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
let inner = if bias {
candle_nn::linear(in_dim, out_dim, vb)?
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
};
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
}
#[derive(Debug)]
pub struct LayerNorm2d {
weight: Tensor,
bias: Tensor,
num_channels: usize,
eps: f64,
}
impl LayerNorm2d {
pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(num_channels, "weight")?;
let bias = vb.get(num_channels, "bias")?;
Ok(Self {
weight,
bias,
num_channels,
eps,
})
}
}
impl Module for LayerNorm2d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let u = xs.mean_keepdim(1)?;
let xs = xs.broadcast_sub(&u)?;
let s = xs.sqr()?.mean_keepdim(1)?;
let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
.broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
}
}
#[derive(Debug)]
pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
activation: candle_nn::Activation,
span: tracing::Span,
}
impl MlpBlock {
pub fn new(
embedding_dim: usize,
mlp_dim: usize,
activation: candle_nn::Activation,
vb: VarBuilder,
) -> Result<Self> {
let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
Ok(Self {
lin1,
lin2,
activation,
span,
})
}
}
impl Module for MlpBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> {
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> {
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = if args.use_tiny {
model_sam::Sam::new_tiny(vb)? // tiny vit_t
sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks {