mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Add FastViT model. (#2444)
This commit is contained in:
512
candle-transformers/src/models/fastvit.rs
Normal file
512
candle-transformers/src/models/fastvit.rs
Normal file
@ -0,0 +1,512 @@
|
||||
//! FastViT inference implementation based on timm
|
||||
//!
|
||||
//! See "FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization"
|
||||
//! https://arxiv.org/pdf/2303.14189
|
||||
//!
|
||||
//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
|
||||
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Config {
|
||||
exp_ratio: usize,
|
||||
in_channels: usize,
|
||||
blocks: [usize; 4],
|
||||
attn: bool,
|
||||
lkc_use_act: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn t8() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 48,
|
||||
blocks: [2, 2, 4, 2],
|
||||
attn: false,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn t12() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 64,
|
||||
blocks: [2, 2, 6, 2],
|
||||
attn: false,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn s12() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 64,
|
||||
blocks: [2, 2, 6, 2],
|
||||
attn: false,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn sa12() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 64,
|
||||
blocks: [2, 2, 6, 2],
|
||||
attn: true,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn sa24() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 64,
|
||||
blocks: [4, 4, 12, 4],
|
||||
attn: true,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn sa36() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 64,
|
||||
blocks: [6, 6, 18, 6],
|
||||
attn: true,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn ma36() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 76,
|
||||
blocks: [6, 6, 18, 6],
|
||||
attn: true,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
|
||||
// configs used by MobileCLIP's image encoder
|
||||
pub fn mci0() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 64,
|
||||
blocks: [2, 6, 10, 2],
|
||||
attn: true,
|
||||
lkc_use_act: true,
|
||||
}
|
||||
}
|
||||
pub fn mci1() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 64,
|
||||
blocks: [4, 12, 20, 4],
|
||||
attn: true,
|
||||
lkc_use_act: true,
|
||||
}
|
||||
}
|
||||
pub fn mci2() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 80,
|
||||
blocks: [4, 12, 24, 4],
|
||||
attn: true,
|
||||
lkc_use_act: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn conv_norm(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
padding: kernel / 2,
|
||||
groups: in_channels,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
|
||||
let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?;
|
||||
let conv = conv.absorb_bn(&bn)?;
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&conv)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn conv_mlp(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conv = conv_norm(dim, dim, 7, 1, vb.pp("conv"))?;
|
||||
let fc1 = conv2d(dim, dim * exp_ratio, 1, conv2d_cfg, vb.pp("fc1"))?;
|
||||
let fc2 = conv2d(dim * exp_ratio, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&conv)?.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn squeeze_and_excitation(
|
||||
in_channels: usize,
|
||||
squeeze_channels: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
|
||||
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs;
|
||||
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
|
||||
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
|
||||
|
||||
residual.broadcast_mul(&xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
|
||||
// based on the _fuse_bn_tensor method in timm
|
||||
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
|
||||
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
|
||||
let (gamma, beta) = bn.weight_and_bias().unwrap();
|
||||
let mu = bn.running_mean();
|
||||
let sigma = (bn.running_var() + bn.eps())?.sqrt();
|
||||
let gps = (gamma / sigma)?;
|
||||
let bias = (beta - mu * &gps)?;
|
||||
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
|
||||
|
||||
Ok((weights, bias))
|
||||
}
|
||||
|
||||
fn mobileone_block(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
group_size: usize,
|
||||
use_act: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let groups = if group_size == 0 {
|
||||
1
|
||||
} else {
|
||||
in_channels / group_size
|
||||
};
|
||||
|
||||
let padding = kernel / 2;
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
groups,
|
||||
padding,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut w = Tensor::zeros(
|
||||
(out_channels, in_channels / groups, kernel, kernel),
|
||||
DType::F32,
|
||||
vb.device(),
|
||||
)?;
|
||||
let dim = out_channels;
|
||||
|
||||
let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
|
||||
|
||||
let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.0.bn"));
|
||||
let conv_kxk = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_kxk.0.conv"),
|
||||
);
|
||||
|
||||
if let (Ok(conv), Ok(bn)) = (conv_kxk, conv_kxk_bn) {
|
||||
let (wk, bk) = fuse_conv_bn(conv.weight(), bn)?;
|
||||
w = (w + wk)?;
|
||||
b = (b + bk)?;
|
||||
};
|
||||
|
||||
let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"));
|
||||
let conv_scale = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_scale.conv"),
|
||||
);
|
||||
|
||||
if let (Ok(conv), Ok(bn)) = (conv_scale, conv_scale_bn) {
|
||||
let (ws, bs) = fuse_conv_bn(conv.weight(), bn)?;
|
||||
// pad to 3x3
|
||||
let ws = ws
|
||||
.pad_with_zeros(D::Minus1, 1, 1)?
|
||||
.pad_with_zeros(D::Minus2, 1, 1)?;
|
||||
|
||||
w = (w + ws)?;
|
||||
b = (b + bs)?;
|
||||
};
|
||||
|
||||
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("se"));
|
||||
|
||||
// read and reparameterize the identity bn into wi and bi
|
||||
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"));
|
||||
|
||||
if let Ok(id_bn) = identity_bn {
|
||||
let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
|
||||
let id = in_channels / groups;
|
||||
// See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
|
||||
for i in 0..in_channels {
|
||||
if kernel > 1 {
|
||||
weights[i * kernel * kernel + 4] = 1.0;
|
||||
} else {
|
||||
weights[i * (id + 1)] = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
|
||||
let (wi, bi) = fuse_conv_bn(weights, id_bn)?;
|
||||
|
||||
w = (w + wi)?;
|
||||
b = (b + bi)?;
|
||||
};
|
||||
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.apply(&reparam_conv)?;
|
||||
if let Ok(f) = &se {
|
||||
xs = xs.apply(f)?;
|
||||
}
|
||||
if use_act {
|
||||
xs = xs.gelu_erf()?;
|
||||
};
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn repmixer(dim: usize, kernel: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
|
||||
let norm = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("norm"))?;
|
||||
let mixer = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("mixer"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs.clone();
|
||||
let xs = (xs.apply(&mixer)? - xs.apply(&norm)?)?;
|
||||
let xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
|
||||
let xs = (xs + residual)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn repmixer_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
|
||||
let token_mixer = repmixer(dim, 3, vb.pp("token_mixer"))?;
|
||||
let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs.apply(&token_mixer)?;
|
||||
let mut xs = residual.apply(&mlp)?;
|
||||
xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
|
||||
let xs = (xs + residual)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn positional_encoding(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 3,
|
||||
groups: dim,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conv = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("pos_enc"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = (xs + xs.apply(&conv)?)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn attention(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?;
|
||||
let proj = linear(dim, dim, vb.pp("proj"))?;
|
||||
let num_heads = 32;
|
||||
let head_dim = dim / num_heads;
|
||||
let scale = (head_dim as f64).powf(-0.5);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.clone();
|
||||
let (b, c, h, w) = xs.dims4()?;
|
||||
let n = h * w;
|
||||
let xs = xs.flatten_from(2)?.transpose(D::Minus1, D::Minus2)?;
|
||||
let qkv = xs
|
||||
.apply(&qkv)?
|
||||
.reshape((b, n, 3, num_heads, head_dim))?
|
||||
.permute((2, 0, 3, 1, 4))?;
|
||||
|
||||
let q = qkv.get(0)?;
|
||||
let k = qkv.get(1)?;
|
||||
let v = qkv.get(2)?;
|
||||
|
||||
let q = (q * scale)?;
|
||||
|
||||
let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
|
||||
let att = softmax(&att, D::Minus1)?;
|
||||
let xs = att.matmul(&v)?;
|
||||
|
||||
let xs = xs.transpose(1, 2)?.reshape((b, n, c))?;
|
||||
let xs = xs.apply(&proj)?;
|
||||
let xs = xs.transpose(D::Minus1, D::Minus2)?.reshape((b, c, h, w))?;
|
||||
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn attention_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let gamma1 = vb.get((dim, 1, 1), "layer_scale_1.gamma")?;
|
||||
let gamma2 = vb.get((dim, 1, 1), "layer_scale_2.gamma")?;
|
||||
let norm = batch_norm(dim, 1e-5, vb.pp("norm"))?;
|
||||
let token_mixer = attention(dim, vb.pp("token_mixer"))?;
|
||||
let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.clone();
|
||||
let xs = (&xs
|
||||
+ &xs
|
||||
.apply_t(&norm, false)?
|
||||
.apply(&token_mixer)?
|
||||
.broadcast_mul(&gamma1.reshape((1, (), 1, 1))?)?)?;
|
||||
|
||||
let xs = (&xs
|
||||
+ &xs
|
||||
.apply(&mlp)?
|
||||
.broadcast_mul(&gamma2.reshape((1, (), 1, 1))?)?)?;
|
||||
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn fastvit_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let nblocks = cfg.blocks[idx];
|
||||
let mut blocks = Vec::with_capacity(nblocks);
|
||||
|
||||
let dim = cfg.in_channels << idx;
|
||||
let downsample = fastvit_patch_embed(dim / 2, dim, cfg.lkc_use_act, vb.pp("downsample"));
|
||||
for block_idx in 0..nblocks {
|
||||
let block = if cfg.attn && idx == 3 {
|
||||
attention_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
|
||||
} else {
|
||||
repmixer_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
|
||||
};
|
||||
blocks.push(block);
|
||||
}
|
||||
let pos_emb = positional_encoding(dim, vb.pp("pos_emb"));
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
if let Ok(ds) = &downsample {
|
||||
xs = xs.apply(ds)?;
|
||||
}
|
||||
if let Ok(pos) = &pos_emb {
|
||||
xs = xs.apply(pos)?;
|
||||
}
|
||||
for block in blocks.iter() {
|
||||
xs = xs.apply(block)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn fastvit_patch_embed(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
use_act: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?;
|
||||
let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?;
|
||||
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("proj.0.se"));
|
||||
let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = (xs.apply(&lk)? + xs.apply(&sk)?)?;
|
||||
if let Ok(f) = &se {
|
||||
xs = xs.apply(f)?;
|
||||
}
|
||||
if use_act {
|
||||
xs = xs.gelu_erf()?;
|
||||
};
|
||||
let xs = xs.apply(&mb)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn fastvit_stem(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let mb0 = mobileone_block(in_channels, out_channels, 3, 2, 0, true, vb.pp(0))?;
|
||||
let mb1 = mobileone_block(out_channels, out_channels, 3, 2, 1, true, vb.pp(1))?;
|
||||
let mb2 = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp(2))?;
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&mb0)?.apply(&mb1)?.apply(&mb2)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Build a fastvit model for a given configuration.
|
||||
fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let cls = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let linear = linear(cfg.in_channels * 16, nclasses, vb.pp("head.fc"))?;
|
||||
Some(linear)
|
||||
}
|
||||
};
|
||||
|
||||
let stem = fastvit_stem(3, cfg.in_channels, vb.pp("stem"))?;
|
||||
let final_conv = mobileone_block(
|
||||
cfg.in_channels * 8,
|
||||
cfg.in_channels * 16,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
true,
|
||||
vb.pp("final_conv"),
|
||||
)?;
|
||||
|
||||
let vb = vb.pp("stages");
|
||||
let stage1 = fastvit_stage(cfg, 0, vb.pp(0))?;
|
||||
let stage2 = fastvit_stage(cfg, 1, vb.pp(1))?;
|
||||
let stage3 = fastvit_stage(cfg, 2, vb.pp(2))?;
|
||||
let stage4 = fastvit_stage(cfg, 3, vb.pp(3))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&stem)?
|
||||
.apply(&stage1)?
|
||||
.apply(&stage2)?
|
||||
.apply(&stage3)?
|
||||
.apply(&stage4)?
|
||||
.apply(&final_conv)?;
|
||||
|
||||
match &cls {
|
||||
None => Ok(xs),
|
||||
Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn fastvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
fastvit_model(cfg, Some(nclasses), vb)
|
||||
}
|
||||
|
||||
pub fn fastvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
fastvit_model(cfg, None, vb)
|
||||
}
|
@ -19,6 +19,7 @@ pub mod efficientvit;
|
||||
pub mod encodec;
|
||||
pub mod eva2;
|
||||
pub mod falcon;
|
||||
pub mod fastvit;
|
||||
pub mod flux;
|
||||
pub mod gemma;
|
||||
pub mod gemma2;
|
||||
|
Reference in New Issue
Block a user