Add ConvNeXt-V2 and smaller model variants. (#1709)

Author:       Jani Monoses
Date:         2024-02-14 11:53:07 +02:00
Committed by: GitHub
Parent:       b60064780d
Commit:       68f7655895

3 changed files with 214 additions and 51 deletions

candle-examples/examples/convnext/README.md

@@ -1,6 +1,7 @@
 # candle-convnext
 
-[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
+[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
+[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
 
 This candle implementation uses a pre-trained ConvNeXt network for inference. The
 classification head has been trained on the ImageNet dataset and returns the

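For readers new to this example, the sketch below shows how the pieces fit together: fetch a timm checkpoint from the Hugging Face Hub and rebuild the classifier. This is a sketch rather than code from this commit; the `convnext::convnext` constructor and the `model.safetensors` filename are assumed to follow the conventions used elsewhere in candle.

```rust
use candle::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::convnext;

// Sketch: download a ConvNeXt-V2 checkpoint and build the 1000-class
// ImageNet classifier described in the README (assumed API conventions).
fn build() -> anyhow::Result<candle_nn::Func<'static>> {
    let api = hf_hub::api::sync::Api::new()?;
    let model_file = api
        .model("timm/convnextv2_atto.fcmae_ft_in1k".into())
        .get("model.safetensors")?;
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &Device::Cpu)? };
    Ok(convnext::convnext(&convnext::Config::atto(), 1000, vb)?)
}
```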
candle-examples/examples/convnext/main.rs

@@ -12,38 +12,62 @@ use candle_transformers::models::convnext;
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Which {
+    Atto,
+    Femto,
+    Pico,
+    Nano,
     Tiny,
     Small,
     Base,
     Large,
+    AttoV2,
+    FemtoV2,
+    PicoV2,
+    NanoV2,
+    TinyV2,
+    BaseV2,
+    LargeV2,
     XLarge,
+    Huge,
 }
 
 impl Which {
     fn model_filename(&self) -> String {
         let name = match self {
-            Self::Tiny => "tiny",
-            Self::Small => "small",
-            Self::Base => "base",
-            Self::Large => "large",
-            Self::XLarge => "xlarge",
-        };
-        // The XLarge model only has an ImageNet-22K variant
-        let variant = match self {
-            Self::XLarge => "fb_in22k_ft_in1k",
-            _ => "fb_in1k",
+            Self::Atto => "convnext_atto.d2_in1k",
+            Self::Femto => "convnext_femto.d1_in1k",
+            Self::Pico => "convnext_pico.d1_in1k",
+            Self::Nano => "convnext_nano.d1h_in1k",
+            Self::Tiny => "convnext_tiny.fb_in1k",
+            Self::Small => "convnext_small.fb_in1k",
+            Self::Base => "convnext_base.fb_in1k",
+            Self::Large => "convnext_large.fb_in1k",
+            Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
+            Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
+            Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
+            Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
+            Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
+            Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
+            Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
+            Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
+            Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
         };
 
-        format!("timm/convnext_{name}.{variant}")
+        format!("timm/{name}")
     }
 
     fn config(&self) -> convnext::Config {
         match self {
-            Self::Tiny => convnext::Config::tiny(),
+            Self::Atto | Self::AttoV2 => convnext::Config::atto(),
+            Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
+            Self::Pico | Self::PicoV2 => convnext::Config::pico(),
+            Self::Nano | Self::NanoV2 => convnext::Config::nano(),
+            Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
             Self::Small => convnext::Config::small(),
-            Self::Base => convnext::Config::base(),
-            Self::Large => convnext::Config::large(),
+            Self::Base | Self::BaseV2 => convnext::Config::base(),
+            Self::Large | Self::LargeV2 => convnext::Config::large(),
             Self::XLarge => convnext::Config::xlarge(),
+            Self::Huge => convnext::Config::huge(),
         }
     }
 }

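A quick illustration of the new mapping (a hypothetical check relying on the `Which` enum above): v1 and v2 variants of a given size share one `Config`; what changes between them is the checkpoint contents, since v2 files carry GRN weights and omit the layer-scale `gamma`.

```rust
// Hypothetical check against the mapping above.
fn check_mapping() {
    let which = Which::TinyV2;
    assert_eq!(which.model_filename(), "timm/convnextv2_tiny.fcmae_ft_in1k");
    // Same blocks/channels as v1 tiny; only the weights differ.
    let _cfg = which.config();
}
```

With clap's `ValueEnum`, the variants surface on the command line in kebab-case, so an invocation along the lines of `cargo run --example convnext --release -- --which tiny-v2 --image tiger.jpg` should select the new models (the `--image` flag name is an assumption based on the other candle vision examples).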
candle-transformers/src/models/convnext.rs

@@ -2,10 +2,16 @@
 //!
 //! See "A ConvNet for the 2020s" Liu et al. 2022
 //! <https://arxiv.org/abs/2201.03545>
-//! Original code: https://github.com/facebookresearch/ConvNeXt/
+//! and
+//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
+//! <https://arxiv.org/abs/2301.00808>
+//! Original code:
+//! https://github.com/facebookresearch/ConvNeXt/
+//! https://github.com/facebookresearch/ConvNeXt-V2/
 //! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
 
+use candle::shape::ShapeWithOneHole;
 use candle::{Result, D};
 use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
@@ -13,31 +19,71 @@ use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
 pub struct Config {
     blocks: [usize; 4],
     channels: [usize; 4],
+    use_conv_mlp: bool,
 }
 
 impl Config {
+    pub fn atto() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [40, 80, 160, 320],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn femto() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [48, 96, 192, 384],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn pico() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [64, 128, 256, 512],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn nano() -> Self {
+        Self {
+            blocks: [2, 2, 8, 2],
+            channels: [80, 160, 320, 640],
+            use_conv_mlp: true,
+        }
+    }
+
     pub fn tiny() -> Self {
         Self {
             blocks: [3, 3, 9, 3],
             channels: [96, 192, 384, 768],
+            use_conv_mlp: false,
         }
     }
 
     pub fn small() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [96, 192, 384, 768],
+            use_conv_mlp: false,
         }
     }
 
     pub fn base() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [128, 256, 512, 1024],
+            use_conv_mlp: false,
        }
     }
 
     pub fn large() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [192, 384, 768, 1536],
+            use_conv_mlp: false,
         }
     }
@@ -45,8 +91,68 @@ impl Config {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [256, 512, 1024, 2048],
+            use_conv_mlp: false,
         }
     }
+
+    pub fn huge() -> Self {
+        Self {
+            blocks: [3, 3, 27, 3],
+            channels: [352, 704, 1408, 2816],
+            use_conv_mlp: false,
+        }
+    }
 }
+
+// Layer norm for data in channels-last format.
+fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let norm = layer_norm(dim, 1e-6, vb)?;
+    Ok(Func::new(move |xs| xs.apply(&norm)))
+}
+
+// Layer norm for data in channels-first format.
+fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let norm = layer_norm(dim, 1e-6, vb)?;
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&norm)?
+            .permute((0, 3, 1, 2))?;
+        Ok(xs)
+    }))
+}
+
+// Global response normalization layer
+// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
+fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
+    let (shape, spatial_dim, channel_dim) = if channels_last {
+        ((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
+    } else {
+        ((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
+    };
+    let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
+    let beta = vb.get(dim, "bias")?.reshape(&shape)?;
+
+    Ok(Func::new(move |xs| {
+        let residual = xs;
+        let gx = xs
+            .sqr()?
+            .sum_keepdim(spatial_dim)?
+            .mean_keepdim(spatial_dim)?
+            .sqrt()?;
+        let gxmean = gx.mean_keepdim(channel_dim)?;
+        let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
+        let xs = xs
+            .broadcast_mul(&nx)?
+            .broadcast_mul(&gamma)?
+            .broadcast_add(&beta)?;
+        xs + residual
+    }))
+}
 
 // Initial downsampling via a patchify layer.
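For reference, `convnext2_grn` computes the ConvNeXt-V2 global response normalization: G(x) is the per-channel L2 norm over the spatial dims, N(x) = G(x) / (mean over channels of G(x) + 1e-6), and the output is gamma * x * N(x) + beta + x. Below is a standalone sketch of the same arithmetic on a channels-first tensor using only public candle tensor ops; the function name and shapes are illustrative, not part of the commit.

```rust
use candle::{Device, Result, Tensor};

// Standalone GRN sketch on a (N, C, H, W) tensor. With gamma = beta = 0,
// as at initialization, the layer reduces to the identity.
fn grn_demo() -> Result<()> {
    let xs = Tensor::randn(0f32, 1f32, (1, 8, 4, 4), &Device::Cpu)?;
    // G(x): per-channel L2 norm over the spatial dims -> (1, 8, 1, 1).
    let gx = xs.sqr()?.sum_keepdim([2, 3])?.sqrt()?;
    // N(x): divisive normalization across the channel dim.
    let nx = gx.broadcast_div(&(gx.mean_keepdim(1)? + 1e-6)?)?;
    // gamma * (x * N(x)) + beta + x, here with gamma = beta = 0.
    let scaled = ((xs.broadcast_mul(&nx)? * 0.0)? + 0.0)?;
    let ys = (scaled + &xs)?;
    assert_eq!(ys.dims(), xs.dims());
    Ok(())
}
```

Note the `sum_keepdim` followed by `mean_keepdim` over the same spatial dims in the committed code: after the sum those dims have size one, so the mean is a no-op and the result is the plain L2 norm.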
@@ -56,16 +162,9 @@ fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
         ..Default::default()
     };
     let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
-    let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
-    Ok(Func::new(move |xs| {
-        // The layer norm works with channels-last format.
-        let xs = xs
-            .apply(&patchify)?
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .permute((0, 3, 1, 2))?;
-        Ok(xs)
-    }))
+    let norm = layer_norm_cf(out_channels, vb.pp(1))?;
+
+    Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
 }
 
 // Downsampling applied after the stages.
@@ -74,31 +173,49 @@ fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
         stride: 2,
         ..Default::default()
     };
-    let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
+    let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
     let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
-    Ok(Func::new(move |xs| {
-        let xs = xs
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .permute((0, 3, 1, 2))?
-            .apply(&conv)?;
-        Ok(xs)
-    }))
+
+    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
 }
 
-// MLP equivalent of pointwise convolutions.
+// MLP block from the original paper with optional GRN layer (v2 models).
 fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
     let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
     let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
+    let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
 
     Ok(Func::new(move |xs| {
-        let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
+        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
+        if let Ok(g) = &grn {
+            xs = xs.apply(g)?;
+        }
+        xs = xs.apply(&fc2)?;
         Ok(xs)
     }))
 }
 
-// A block consisting of a depthwise convolution, a MLP and layer scaling.
-fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
+fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        ..Default::default()
+    };
+    let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
+    let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
+    let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
+
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
+        if let Ok(g) = &grn {
+            xs = xs.apply(g)?;
+        }
+        xs = xs.apply(&fc2)?;
+        Ok(xs)
+    }))
+}
+
+// A block consisting of a depthwise convolution, a MLP and layer scaling (v1 models only).
+fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
     let conv2d_cfg = Conv2dConfig {
         groups: dim,
         padding: 3,
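The two MLP flavors above are the same per-pixel linear map in different layouts: the checkpoints for the small models (atto through nano, `use_conv_mlp: true`) store the MLP as 1x1 convolutions operating channels-first, while the larger models use `Linear` layers channels-last. A quick equivalence check, with illustrative shapes and not code from this commit:

```rust
use candle::{Device, Result, Tensor};

// A 1x1 convolution equals a linear layer applied at every pixel; only the
// memory layout differs. 4 -> 8 channels over a 5x5 feature map.
fn pointwise_vs_linear() -> Result<()> {
    let dev = Device::Cpu;
    let w = Tensor::randn(0f32, 1f32, (8, 4, 1, 1), &dev)?; // conv kernel
    let xs = Tensor::randn(0f32, 1f32, (1, 4, 5, 5), &dev)?;
    // Channels-first: conv2d with padding 0, stride 1, dilation 1, groups 1.
    let conv_out = xs.conv2d(&w, 0, 1, 1, 1)?;
    // Channels-last: permute, matmul against the flattened (4, 8) kernel.
    let lin_out = xs
        .permute((0, 2, 3, 1))?
        .contiguous()? // materialize the channels-last layout before matmul
        .broadcast_matmul(&w.reshape((8, 4))?.t()?)?
        .permute((0, 3, 1, 2))?;
    let max_diff = (&conv_out - &lin_out)?.abs()?.flatten_all()?.max(0)?;
    println!("max abs diff: {}", max_diff.to_scalar::<f32>()?); // ~0
    Ok(())
}
```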
@@ -106,20 +223,36 @@ fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
     };
     let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
 
-    let gamma = vb.get(dim, "gamma")?;
-    let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
-    let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
+    let gamma = vb.get(dim, "gamma");
+    let (mlp, norm) = if use_conv_mlp {
+        (
+            convnext_conv_mlp(dim, vb.pp("mlp"))?,
+            layer_norm_cf(dim, vb.pp("norm"))?,
+        )
+    } else {
+        (
+            convnext_mlp(dim, vb.pp("mlp"))?,
+            layer_norm_cl(dim, vb.pp("norm"))?,
+        )
+    };
 
     Ok(Func::new(move |xs| {
         let residual = xs;
-        let xs = xs
-            .apply(&conv_dw)?
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .apply(&mlp)?
-            .broadcast_mul(&gamma)?
-            .permute((0, 3, 1, 2))?;
+        let mut xs = xs.apply(&conv_dw)?;
+
+        xs = if use_conv_mlp {
+            xs.apply(&norm)?.apply(&mlp)?
+        } else {
+            xs.permute((0, 2, 3, 1))?
+                .apply(&norm)?
+                .apply(&mlp)?
+                .permute((0, 3, 1, 2))?
+        };
+
+        if let Ok(g) = &gamma {
+            xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
+        };
+
         xs + residual
     }))
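One detail worth calling out in `convnext_block`: `vb.get(dim, "gamma")` is now kept as an unwrapped `Result` and doubles as a presence probe, since v1 checkpoints ship a per-channel layer-scale `gamma` while v2 checkpoints do not (GRN's affine parameters take over that role). In isolation the pattern looks like the hypothetical helper below, which is not part of the commit:

```rust
use candle::{Result, Tensor};
use candle_nn::VarBuilder;

// Apply a per-channel scale only if the checkpoint contains one; mirrors
// the `if let Ok(g) = &gamma` branch in convnext_block above.
fn maybe_layer_scale(xs: &Tensor, dim: usize, vb: &VarBuilder) -> Result<Tensor> {
    match vb.get(dim, "gamma") {
        // (C,) -> (1, C, 1, 1) so it broadcasts over (N, C, H, W).
        Ok(g) => xs.broadcast_mul(&g.reshape((1, (), 1, 1))?),
        Err(_) => Ok(xs.clone()),
    }
}
```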
@@ -137,7 +270,11 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
     }
 
     for block_idx in 0..nblocks {
-        blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
+        blocks.push(convnext_block(
+            dim,
+            cfg.use_conv_mlp,
+            vb.pp(format!("blocks.{block_idx}")),
+        )?);
     }
 
     Ok(Func::new(move |xs| {
@@ -149,8 +286,9 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
     }))
 }
 
+// Classification head.
 fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
-    let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
+    let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
     let linear = linear(outputs, nclasses, vb.pp("fc"))?;
     Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
 }
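End to end, the rebuilt network (stem, four stages, then this head) maps a standard 224x224 ImageNet input to 1000 logits. A hedged smoke test, reusing the hypothetical `build()` sketch from the README section above:

```rust
use candle::{DType, Device, Tensor};

// Hypothetical smoke test: one 224x224 RGB image through the full model.
fn smoke_test() -> anyhow::Result<()> {
    let model = build()?; // see the sketch after the README diff
    let img = Tensor::zeros((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
    let logits = img.apply(&model)?;
    assert_eq!(logits.dims(), &[1, 1000]);
    Ok(())
}
```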