Add ConvNeXt-V2 and smaller model variants. (#1709)

Author: Jani Monoses, 2024-02-14 11:53:07 +02:00 (committed by GitHub)
Parent: b60064780d
Commit: 68f7655895
3 changed files with 214 additions and 51 deletions

File: candle-examples/examples/convnext/README.md

@@ -1,6 +1,7 @@
 # candle-convnext
 
-[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
+[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
+[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
 
 This candle implementation uses a pre-trained ConvNeXt network for inference. The
 classification head has been trained on the ImageNet dataset and returns the
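
Usage note (not in the commit): the example can be invoked roughly as below, assuming clap's default kebab-case naming for the `ValueEnum` variants added in main.rs (e.g. `TinyV2` becomes `tiny-v2`); the image path is illustrative:

$ cargo run --example convnext --release -- --which tiny-v2 --image bike.jpg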

File: candle-examples/examples/convnext/main.rs

@@ -12,38 +12,62 @@ use candle_transformers::models::convnext;
 
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Which {
+    Atto,
+    Femto,
+    Pico,
+    Nano,
     Tiny,
     Small,
     Base,
     Large,
+    AttoV2,
+    FemtoV2,
+    PicoV2,
+    NanoV2,
+    TinyV2,
+    BaseV2,
+    LargeV2,
     XLarge,
+    Huge,
 }
 
 impl Which {
     fn model_filename(&self) -> String {
         let name = match self {
-            Self::Tiny => "tiny",
-            Self::Small => "small",
-            Self::Base => "base",
-            Self::Large => "large",
-            Self::XLarge => "xlarge",
-        };
-        // The XLarge model only has an ImageNet-22K variant
-        let variant = match self {
-            Self::XLarge => "fb_in22k_ft_in1k",
-            _ => "fb_in1k",
+            Self::Atto => "convnext_atto.d2_in1k",
+            Self::Femto => "convnext_femto.d1_in1k",
+            Self::Pico => "convnext_pico.d1_in1k",
+            Self::Nano => "convnext_nano.d1h_in1k",
+            Self::Tiny => "convnext_tiny.fb_in1k",
+            Self::Small => "convnext_small.fb_in1k",
+            Self::Base => "convnext_base.fb_in1k",
+            Self::Large => "convnext_large.fb_in1k",
+            Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
+            Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
+            Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
+            Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
+            Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
+            Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
+            Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
+            Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
+            Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
         };
-        format!("timm/convnext_{name}.{variant}")
+        format!("timm/{name}")
     }
 
     fn config(&self) -> convnext::Config {
         match self {
-            Self::Tiny => convnext::Config::tiny(),
+            Self::Atto | Self::AttoV2 => convnext::Config::atto(),
+            Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
+            Self::Pico | Self::PicoV2 => convnext::Config::pico(),
+            Self::Nano | Self::NanoV2 => convnext::Config::nano(),
+            Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
             Self::Small => convnext::Config::small(),
-            Self::Base => convnext::Config::base(),
-            Self::Large => convnext::Config::large(),
+            Self::Base | Self::BaseV2 => convnext::Config::base(),
+            Self::Large | Self::LargeV2 => convnext::Config::large(),
             Self::XLarge => convnext::Config::xlarge(),
+            Self::Huge => convnext::Config::huge(),
         }
     }
 }
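
Reading aid (not in the commit): `model_filename` resolves each variant to a timm checkpoint id on the Hugging Face Hub. A hypothetical test pinning down two of the match arms above:

#[test]
fn which_resolves_to_timm_hub_ids() {
    // Expected values read directly off the match arms above.
    assert_eq!(Which::Tiny.model_filename(), "timm/convnext_tiny.fb_in1k");
    assert_eq!(
        Which::NanoV2.model_filename(),
        "timm/convnextv2_nano.fcmae_ft_in1k"
    );
}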

File: candle-transformers/src/models/convnext.rs

@@ -2,10 +2,16 @@
 //!
 //! See "A ConvNet for the 2020s" Liu et al. 2022
 //! <https://arxiv.org/abs/2201.03545>
+//! and
+//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
+//! <https://arxiv.org/abs/2301.00808>
 //!
-//! Original code: https://github.com/facebookresearch/ConvNeXt/
+//! Original code:
+//! https://github.com/facebookresearch/ConvNeXt/
+//! https://github.com/facebookresearch/ConvNeXt-V2/
 //! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
 
+use candle::shape::ShapeWithOneHole;
 use candle::{Result, D};
 use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
@@ -13,31 +19,71 @@ use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
 
 pub struct Config {
     blocks: [usize; 4],
     channels: [usize; 4],
+    use_conv_mlp: bool,
 }
 
 impl Config {
+    pub fn atto() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [40, 80, 160, 320],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn femto() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [48, 96, 192, 384],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn pico() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [64, 128, 256, 512],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn nano() -> Self {
+        Self {
+            blocks: [2, 2, 8, 2],
+            channels: [80, 160, 320, 640],
+            use_conv_mlp: true,
+        }
+    }
+
     pub fn tiny() -> Self {
         Self {
             blocks: [3, 3, 9, 3],
             channels: [96, 192, 384, 768],
+            use_conv_mlp: false,
         }
     }
 
     pub fn small() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [96, 192, 384, 768],
+            use_conv_mlp: false,
         }
     }
 
     pub fn base() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [128, 256, 512, 1024],
+            use_conv_mlp: false,
         }
     }
 
     pub fn large() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [192, 384, 768, 1536],
+            use_conv_mlp: false,
         }
     }
@@ -45,8 +91,68 @@ impl Config {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [256, 512, 1024, 2048],
+            use_conv_mlp: false,
         }
     }
+
+    pub fn huge() -> Self {
+        Self {
+            blocks: [3, 3, 27, 3],
+            channels: [352, 704, 1408, 2816],
+            use_conv_mlp: false,
+        }
+    }
 }
 
+// Layer norm for data in channels-last format.
+fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let norm = layer_norm(dim, 1e-6, vb)?;
+    Ok(Func::new(move |xs| xs.apply(&norm)))
+}
+
+// Layer norm for data in channels-first format.
+fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let norm = layer_norm(dim, 1e-6, vb)?;
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&norm)?
+            .permute((0, 3, 1, 2))?;
+        Ok(xs)
+    }))
+}
+
+// Global response normalization layer
+// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
+fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
+    let (shape, spatial_dim, channel_dim) = if channels_last {
+        ((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
+    } else {
+        ((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
+    };
+    let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
+    let beta = vb.get(dim, "bias")?.reshape(&shape)?;
+
+    Ok(Func::new(move |xs| {
+        let residual = xs;
+        let gx = xs
+            .sqr()?
+            .sum_keepdim(spatial_dim)?
+            .mean_keepdim(spatial_dim)?
+            .sqrt()?;
+        let gxmean = gx.mean_keepdim(channel_dim)?;
+        let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
+        let xs = xs
+            .broadcast_mul(&nx)?
+            .broadcast_mul(&gamma)?
+            .broadcast_add(&beta)?;
+
+        xs + residual
+    }))
+}
 
 // Initial downsampling via a patchify layer.
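
Reading aid (not in the commit): `convnext2_grn` above implements global response normalization from the ConvNeXt-V2 paper. For an input x with C channels and per-channel parameters \gamma, \beta, with \epsilon = 10^{-6}, it computes:

G(x)_c = \lVert x_{\cdot,\cdot,c} \rVert_2 \quad (\text{L2 norm over the spatial positions})
N(x)_c = G(x)_c \,/\, \bigl(\tfrac{1}{C}\textstyle\sum_{c'} G(x)_{c'} + \epsilon\bigr)
y = \gamma \odot \bigl(x \cdot N(x)\bigr) + \beta + x

Note that the `mean_keepdim` over the already-reduced spatial dims is a no-op after `sum_keepdim`, so `gx` is exactly the spatial L2 norm.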
@@ -56,16 +162,9 @@ fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
         ..Default::default()
     };
     let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
-    let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
-    Ok(Func::new(move |xs| {
-        // The layer norm works with channels-last format.
-        let xs = xs
-            .apply(&patchify)?
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .permute((0, 3, 1, 2))?;
-        Ok(xs)
-    }))
+    let norm = layer_norm_cf(out_channels, vb.pp(1))?;
+
+    Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
 }
 
 // Downsampling applied after the stages.
@@ -74,31 +173,49 @@ fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
         stride: 2,
         ..Default::default()
     };
-    let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
+    let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
     let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
 
-    Ok(Func::new(move |xs| {
-        let xs = xs
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .permute((0, 3, 1, 2))?
-            .apply(&conv)?;
-        Ok(xs)
-    }))
+    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
 }
 
-// MLP equivalent of pointwise convolutions.
+// MLP block from the original paper with optional GRN layer (v2 models).
 fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
     let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
     let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
+    let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
 
     Ok(Func::new(move |xs| {
-        let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
+        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
+        if let Ok(g) = &grn {
+            xs = xs.apply(g)?;
+        }
+        xs = xs.apply(&fc2)?;
         Ok(xs)
     }))
 }
 
-// A block consisting of a depthwise convolution, a MLP and layer scaling.
-fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
+fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        ..Default::default()
+    };
+    let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
+    let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
+    let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
+
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
+        if let Ok(g) = &grn {
+            xs = xs.apply(g)?;
+        }
+        xs = xs.apply(&fc2)?;
+        Ok(xs)
+    }))
+}
+
+// A block consisting of a depthwise convolution, a MLP and layer scaling (v1 models only).
+fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
     let conv2d_cfg = Conv2dConfig {
         groups: dim,
         padding: 3,
@@ -106,20 +223,36 @@ fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
     };
 
     let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
 
-    let gamma = vb.get(dim, "gamma")?;
-    let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
-    let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
+    let gamma = vb.get(dim, "gamma");
+    let (mlp, norm) = if use_conv_mlp {
+        (
+            convnext_conv_mlp(dim, vb.pp("mlp"))?,
+            layer_norm_cf(dim, vb.pp("norm"))?,
+        )
+    } else {
+        (
+            convnext_mlp(dim, vb.pp("mlp"))?,
+            layer_norm_cl(dim, vb.pp("norm"))?,
+        )
+    };
 
     Ok(Func::new(move |xs| {
         let residual = xs;
-        let xs = xs
-            .apply(&conv_dw)?
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .apply(&mlp)?
-            .broadcast_mul(&gamma)?
-            .permute((0, 3, 1, 2))?;
+        let mut xs = xs.apply(&conv_dw)?;
+
+        xs = if use_conv_mlp {
+            xs.apply(&norm)?.apply(&mlp)?
+        } else {
+            xs.permute((0, 2, 3, 1))?
+                .apply(&norm)?
+                .apply(&mlp)?
+                .permute((0, 3, 1, 2))?
+        };
+        if let Ok(g) = &gamma {
+            xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
+        };
 
         xs + residual
     }))
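
Reading aid (not in the commit): in equation form the block computes the standard ConvNeXt residual update

y = x + \gamma \odot \mathrm{MLP}\bigl(\mathrm{LN}\bigl(\mathrm{DWConv}_{7\times 7}(x)\bigr)\bigr)

where the layer-scale \gamma is applied only when the checkpoint actually contains a `gamma` tensor (v1 models); v2 checkpoints omit it and rely on the GRN inside the MLP instead.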
@@ -137,7 +270,11 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
     }
 
     for block_idx in 0..nblocks {
-        blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
+        blocks.push(convnext_block(
+            dim,
+            cfg.use_conv_mlp,
+            vb.pp(format!("blocks.{block_idx}")),
+        )?);
     }
 
     Ok(Func::new(move |xs| {
@@ -149,8 +286,9 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
     }))
 }
 
 // Classification head.
 fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
-    let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
+    let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
     let linear = linear(outputs, nclasses, vb.pp("fc"))?;
     Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
 }
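
Closing note on the v1/v2 dispatch (not in the commit): `vb.get` returns a `Result`, and v2-only tensors such as the `grn` weights are simply absent from v1 checkpoints, so `convnext2_grn` (and the `gamma` lookup) fail at load time and the `if let Ok(..)` branches are skipped at runtime. A minimal sketch of the same pattern, with hypothetical names:

// Hypothetical illustration of the optional-weight pattern used above.
use candle::{Result, Tensor};
use candle_nn::VarBuilder;

fn optional_scale(dim: usize, vb: VarBuilder) -> impl Fn(&Tensor) -> Result<Tensor> {
    // Err(..) if the checkpoint has no "scale" tensor; nothing fails here.
    let scale = vb.get(dim, "scale");
    move |xs| match &scale {
        Ok(s) => xs.broadcast_mul(s),
        Err(_) => Ok(xs.clone()),
    }
}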