Add MobileOne model. (#1595)

* Add MobileOne model.

* Clippy fixes

* Remove a comment.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Jani Monoses
2024-01-16 07:34:16 +02:00
committed by GitHub
parent 7e3349d7c3
commit 5270224f40
4 changed files with 452 additions and 0 deletions

View File

@ -0,0 +1,22 @@
# candle-mobileone
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet : 2.58%
unicycle, monocycle : 1.70%
alp : 0.21%
```

View File

@ -0,0 +1,96 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
S0,
S1,
S2,
S3,
S4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::S0 => "s0",
Self::S1 => "s1",
Self::S2 => "s2",
Self::S3 => "s3",
Self::S4 => "s4",
};
format!("timm/mobileone_{}.apple_in1k", name)
}
fn config(&self) -> mobileone::Config {
match self {
Self::S0 => mobileone::Config::s0(),
Self::S1 => mobileone::Config::s1(),
Self::S2 => mobileone::Config::s2(),
Self::S3 => mobileone::Config::s3(),
Self::S4 => mobileone::Config::s4(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::S0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -0,0 +1,333 @@
//! MobileOne inference implementation based on timm and candle-repvgg
//!
//! See "MobileOne: An Improved One millisecond Mobile Backbone"
//! https://arxiv.org/abs/2206.04040
use candle::{DType, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
Func, VarBuilder,
};
struct StageConfig {
blocks: usize,
channels: usize,
}
// The architecture in the paper has 6 stages. The timm implementation uses an equivalent form
// by concatenating the 5th stage (starts with stride 1) to the previous one.
const STAGES: [StageConfig; 5] = [
StageConfig {
blocks: 1,
channels: 64,
},
StageConfig {
blocks: 2,
channels: 64,
},
StageConfig {
blocks: 8,
channels: 128,
},
StageConfig {
blocks: 10,
channels: 256,
},
StageConfig {
blocks: 1,
channels: 512,
},
];
#[derive(Clone)]
pub struct Config {
/// overparameterization factor
k: usize,
/// per-stage channel number multipliers
alphas: [f32; 5],
}
impl Config {
pub fn s0() -> Self {
Self {
k: 4,
alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
}
}
pub fn s1() -> Self {
Self {
k: 1,
alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
}
}
pub fn s2() -> Self {
Self {
k: 1,
alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
}
}
pub fn s3() -> Self {
Self {
k: 1,
alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
}
}
pub fn s4() -> Self {
Self {
k: 1,
alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
}
}
}
// SE blocks are used in the last stages of the s4 variant.
fn squeeze_and_excitation(
in_channels: usize,
squeeze_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
Ok(Func::new(move |xs| {
let residual = xs;
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
residual.broadcast_mul(&xs)
}))
}
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
let (gamma, beta) = bn.weight_and_bias().unwrap();
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;
let bias = (beta - mu * &gps)?;
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
Ok((weights, bias))
}
// A mobileone block has a different training time and inference time architecture.
// The latter is a simple and efficient equivalent transformation of the former
// realized by a structural reparameterization technique, where convolutions
// along with identity branches and batchnorm layers are fused into a single convolution.
#[allow(clippy::too_many_arguments)]
fn mobileone_block(
has_identity: bool,
k: usize,
dim: usize,
stride: usize,
padding: usize,
groups: usize,
kernel: usize,
in_channels: usize,
out_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride,
padding,
groups,
..Default::default()
};
let mut w = Tensor::zeros(
(out_channels, in_channels / groups, kernel, kernel),
DType::F32,
vb.device(),
)?;
let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
// k is the training-time overparameterization factor, larger than 1 only in the s0 variant
for i in 0..k {
let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
let conv_kxk = conv2d_no_bias(
in_channels,
out_channels,
kernel,
conv2d_cfg,
vb.pp(format!("conv_kxk.{i}.conv")),
)?;
let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
w = (w + wk)?;
b = (b + bk)?;
}
if kernel > 1 {
let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
let conv_scale = conv2d_no_bias(
in_channels,
out_channels,
1,
conv2d_cfg,
vb.pp("conv_scale.conv"),
)?;
let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
// resize to 3x3
ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;
w = (w + ws)?;
b = (b + bs)?;
}
// Use SE blocks if present (last layers of the s4 variant)
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));
// read and reparameterize the identity bn into wi and bi
if has_identity {
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
let id = in_channels / groups;
// See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
for i in 0..in_channels {
if kernel > 1 {
weights[i * kernel * kernel + 4] = 1.0;
} else {
weights[i * (id + 1)] = 1.0;
}
}
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
w = (w + wi)?;
b = (b + bi)?;
}
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
Ok(Func::new(move |xs| {
let mut xs = xs.apply(&reparam_conv)?;
if let Ok(f) = &se {
xs = xs.apply(f)?;
}
xs = xs.relu()?;
Ok(xs)
}))
}
// Get the number of output channels per stage taking into account the multipliers
fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
let channels = STAGES[stage].channels as f32;
let alpha = cfg.alphas[stage];
match stage {
0 => std::cmp::min(64, (channels * alpha) as usize),
_ => (channels * alpha) as usize,
}
}
// Each stage is made of blocks. The first layer always downsamples with stride 2.
// All but the first block have a residual connection.
fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
let nblocks = STAGES[idx].blocks;
let mut blocks = Vec::with_capacity(nblocks);
let mut in_channels = output_channels_per_stage(cfg, idx - 1);
for block_idx in 0..nblocks {
let out_channels = output_channels_per_stage(cfg, idx);
let (has_identity, stride) = if block_idx == 0 {
(false, 2)
} else {
(true, 1)
};
// depthwise convolution layer
blocks.push(mobileone_block(
has_identity,
cfg.k,
in_channels,
stride,
1,
in_channels,
3,
in_channels,
in_channels,
vb.pp(block_idx * 2),
)?);
// pointwise convolution layer
blocks.push(mobileone_block(
has_identity,
cfg.k,
out_channels,
1, // stride
0, // padding
1, // groups
1, // kernel
in_channels,
out_channels,
vb.pp(block_idx * 2 + 1),
)?);
in_channels = out_channels;
}
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
for block in blocks.iter() {
xs = xs.apply(block)?
}
Ok(xs)
}))
}
// Build a mobileone model for a given configuration.
fn mobileone_model(
config: &Config,
nclasses: Option<usize>,
vb: VarBuilder,
) -> Result<Func<'static>> {
let cls = match nclasses {
None => None,
Some(nclasses) => {
let outputs = output_channels_per_stage(config, 4);
let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
Some(linear)
}
};
let stem_dim = output_channels_per_stage(config, 0);
let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
let vb = vb.pp("stages");
let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
let stage4 = mobileone_stage(config, 4, vb.pp(3))?;
Ok(Func::new(move |xs| {
let xs = xs
.apply(&stem)?
.apply(&stage1)?
.apply(&stage2)?
.apply(&stage3)?
.apply(&stage4)?
.mean(D::Minus2)?
.mean(D::Minus1)?;
match &cls {
None => Ok(xs),
Some(cls) => xs.apply(cls),
}
}))
}
pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
mobileone_model(cfg, Some(nclasses), vb)
}
pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
mobileone_model(cfg, None, vb)
}

View File

@ -15,6 +15,7 @@ pub mod marian;
pub mod mistral; pub mod mistral;
pub mod mixformer; pub mod mixformer;
pub mod mixtral; pub mod mixtral;
pub mod mobileone;
pub mod mpt; pub mod mpt;
pub mod persimmon; pub mod persimmon;
pub mod phi; pub mod phi;