mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add ConvNeXt-V2 and smaller model variants. (#1709)
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
# candle-convnext
|
||||
|
||||
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
|
||||
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
|
||||
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
|
||||
|
||||
This candle implementation uses a pre-trained ConvNeXt network for inference. The
|
||||
classification head has been trained on the ImageNet dataset and returns the
|
||||
|
@ -12,38 +12,62 @@ use candle_transformers::models::convnext;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
Atto,
|
||||
Femto,
|
||||
Pico,
|
||||
Nano,
|
||||
Tiny,
|
||||
Small,
|
||||
Base,
|
||||
Large,
|
||||
AttoV2,
|
||||
FemtoV2,
|
||||
PicoV2,
|
||||
NanoV2,
|
||||
TinyV2,
|
||||
BaseV2,
|
||||
LargeV2,
|
||||
XLarge,
|
||||
Huge,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::Tiny => "tiny",
|
||||
Self::Small => "small",
|
||||
Self::Base => "base",
|
||||
Self::Large => "large",
|
||||
Self::XLarge => "xlarge",
|
||||
};
|
||||
// The XLarge model only has an ImageNet-22K variant
|
||||
let variant = match self {
|
||||
Self::XLarge => "fb_in22k_ft_in1k",
|
||||
_ => "fb_in1k",
|
||||
Self::Atto => "convnext_atto.d2_in1k",
|
||||
Self::Femto => "convnext_femto.d1_in1k",
|
||||
Self::Pico => "convnext_pico.d1_in1k",
|
||||
Self::Nano => "convnext_nano.d1h_in1k",
|
||||
Self::Tiny => "convnext_tiny.fb_in1k",
|
||||
Self::Small => "convnext_small.fb_in1k",
|
||||
Self::Base => "convnext_base.fb_in1k",
|
||||
Self::Large => "convnext_large.fb_in1k",
|
||||
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
|
||||
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
|
||||
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
|
||||
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
|
||||
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
|
||||
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
|
||||
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
|
||||
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
|
||||
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
|
||||
};
|
||||
|
||||
format!("timm/convnext_{name}.{variant}")
|
||||
format!("timm/{name}")
|
||||
}
|
||||
|
||||
fn config(&self) -> convnext::Config {
|
||||
match self {
|
||||
Self::Tiny => convnext::Config::tiny(),
|
||||
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
|
||||
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
|
||||
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
|
||||
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
|
||||
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
|
||||
Self::Small => convnext::Config::small(),
|
||||
Self::Base => convnext::Config::base(),
|
||||
Self::Large => convnext::Config::large(),
|
||||
Self::Base | Self::BaseV2 => convnext::Config::base(),
|
||||
Self::Large | Self::LargeV2 => convnext::Config::large(),
|
||||
Self::XLarge => convnext::Config::xlarge(),
|
||||
Self::Huge => convnext::Config::huge(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,10 +2,16 @@
|
||||
//!
|
||||
//! See "A ConvNet for the 2020s" Liu et al. 2022
|
||||
//! <https://arxiv.org/abs/2201.03545>
|
||||
//! and
|
||||
//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
|
||||
//! <https://arxiv.org/abs/2301.00808>
|
||||
|
||||
//! Original code: https://github.com/facebookresearch/ConvNeXt/
|
||||
//! Original code:
|
||||
//! https://github.com/facebookresearch/ConvNeXt/
|
||||
//! https://github.com/facebookresearch/ConvNeXt-V2/
|
||||
//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
|
||||
|
||||
use candle::shape::ShapeWithOneHole;
|
||||
use candle::{Result, D};
|
||||
use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
|
||||
|
||||
@ -13,31 +19,71 @@ use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
|
||||
pub struct Config {
|
||||
blocks: [usize; 4],
|
||||
channels: [usize; 4],
|
||||
use_conv_mlp: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn atto() -> Self {
|
||||
Self {
|
||||
blocks: [2, 2, 6, 2],
|
||||
channels: [40, 80, 160, 320],
|
||||
use_conv_mlp: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn femto() -> Self {
|
||||
Self {
|
||||
blocks: [2, 2, 6, 2],
|
||||
channels: [48, 96, 192, 384],
|
||||
use_conv_mlp: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pico() -> Self {
|
||||
Self {
|
||||
blocks: [2, 2, 6, 2],
|
||||
channels: [64, 128, 256, 512],
|
||||
use_conv_mlp: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nano() -> Self {
|
||||
Self {
|
||||
blocks: [2, 2, 8, 2],
|
||||
channels: [80, 160, 320, 640],
|
||||
use_conv_mlp: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tiny() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 9, 3],
|
||||
channels: [96, 192, 384, 768],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn small() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [96, 192, 384, 768],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn base() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [128, 256, 512, 1024],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn large() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [192, 384, 768, 1536],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,8 +91,68 @@ impl Config {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [256, 512, 1024, 2048],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn huge() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [352, 704, 1408, 2816],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Layer norm for data in channels-last format.
|
||||
fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let norm = layer_norm(dim, 1e-6, vb)?;
|
||||
|
||||
Ok(Func::new(move |xs| xs.apply(&norm)))
|
||||
}
|
||||
|
||||
// Layer norm for data in channels-first format.
|
||||
fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let norm = layer_norm(dim, 1e-6, vb)?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.permute((0, 2, 3, 1))?
|
||||
.apply(&norm)?
|
||||
.permute((0, 3, 1, 2))?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Global response normalization layer
|
||||
// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
|
||||
fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let (shape, spatial_dim, channel_dim) = if channels_last {
|
||||
((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
|
||||
} else {
|
||||
((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
|
||||
};
|
||||
|
||||
let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
|
||||
let beta = vb.get(dim, "bias")?.reshape(&shape)?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs;
|
||||
let gx = xs
|
||||
.sqr()?
|
||||
.sum_keepdim(spatial_dim)?
|
||||
.mean_keepdim(spatial_dim)?
|
||||
.sqrt()?;
|
||||
|
||||
let gxmean = gx.mean_keepdim(channel_dim)?;
|
||||
let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
|
||||
let xs = xs
|
||||
.broadcast_mul(&nx)?
|
||||
.broadcast_mul(&gamma)?
|
||||
.broadcast_add(&beta)?;
|
||||
|
||||
xs + residual
|
||||
}))
|
||||
}
|
||||
|
||||
// Initial downsampling via a patchify layer.
|
||||
@ -56,16 +162,9 @@ fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
..Default::default()
|
||||
};
|
||||
let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
|
||||
let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
|
||||
Ok(Func::new(move |xs| {
|
||||
// The layer norm works with channels-last format.
|
||||
let xs = xs
|
||||
.apply(&patchify)?
|
||||
.permute((0, 2, 3, 1))?
|
||||
.apply(&norm)?
|
||||
.permute((0, 3, 1, 2))?;
|
||||
Ok(xs)
|
||||
}))
|
||||
let norm = layer_norm_cf(out_channels, vb.pp(1))?;
|
||||
|
||||
Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
|
||||
}
|
||||
|
||||
// Downsampling applied after the stages.
|
||||
@ -74,31 +173,49 @@ fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
|
||||
let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
|
||||
let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.permute((0, 2, 3, 1))?
|
||||
.apply(&norm)?
|
||||
.permute((0, 3, 1, 2))?
|
||||
.apply(&conv)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
|
||||
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
|
||||
}
|
||||
|
||||
// MLP equivalent of pointwise convolutions.
|
||||
// MLP block from the original paper with optional GRN layer (v2 models).
|
||||
fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
|
||||
let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
|
||||
let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
|
||||
let mut xs = xs.apply(&fc1)?.gelu_erf()?;
|
||||
if let Ok(g) = &grn {
|
||||
xs = xs.apply(g)?;
|
||||
}
|
||||
xs = xs.apply(&fc2)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// A block consisting of a depthwise convolution, a MLP and layer scaling.
|
||||
fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
|
||||
fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
|
||||
let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
|
||||
|
||||
let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.apply(&fc1)?.gelu_erf()?;
|
||||
if let Ok(g) = &grn {
|
||||
xs = xs.apply(g)?;
|
||||
}
|
||||
xs = xs.apply(&fc2)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// A block consisting of a depthwise convolution, a MLP and layer scaling (v1 models only).
|
||||
fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
groups: dim,
|
||||
padding: 3,
|
||||
@ -106,20 +223,36 @@ fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
};
|
||||
|
||||
let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
|
||||
let gamma = vb.get(dim, "gamma");
|
||||
|
||||
let gamma = vb.get(dim, "gamma")?;
|
||||
let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
|
||||
let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
|
||||
let (mlp, norm) = if use_conv_mlp {
|
||||
(
|
||||
convnext_conv_mlp(dim, vb.pp("mlp"))?,
|
||||
layer_norm_cf(dim, vb.pp("norm"))?,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
convnext_mlp(dim, vb.pp("mlp"))?,
|
||||
layer_norm_cl(dim, vb.pp("norm"))?,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs;
|
||||
let xs = xs
|
||||
.apply(&conv_dw)?
|
||||
.permute((0, 2, 3, 1))?
|
||||
.apply(&norm)?
|
||||
.apply(&mlp)?
|
||||
.broadcast_mul(&gamma)?
|
||||
.permute((0, 3, 1, 2))?;
|
||||
let mut xs = xs.apply(&conv_dw)?;
|
||||
|
||||
xs = if use_conv_mlp {
|
||||
xs.apply(&norm)?.apply(&mlp)?
|
||||
} else {
|
||||
xs.permute((0, 2, 3, 1))?
|
||||
.apply(&norm)?
|
||||
.apply(&mlp)?
|
||||
.permute((0, 3, 1, 2))?
|
||||
};
|
||||
|
||||
if let Ok(g) = &gamma {
|
||||
xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
|
||||
};
|
||||
|
||||
xs + residual
|
||||
}))
|
||||
@ -137,7 +270,11 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
|
||||
}
|
||||
|
||||
for block_idx in 0..nblocks {
|
||||
blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
|
||||
blocks.push(convnext_block(
|
||||
dim,
|
||||
cfg.use_conv_mlp,
|
||||
vb.pp(format!("blocks.{block_idx}")),
|
||||
)?);
|
||||
}
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
@ -149,8 +286,9 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
|
||||
}))
|
||||
}
|
||||
|
||||
// Classification head.
|
||||
fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
|
||||
let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
|
||||
let linear = linear(outputs, nclasses, vb.pp("fc"))?;
|
||||
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
|
||||
}
|
||||
|
Reference in New Issue
Block a user