Add ConvNeXt-V2 and smaller model variants. (#1709)
candle-examples/examples/convnext/README.md

@@ -1,6 +1,7 @@
 # candle-convnext
 
-[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
+[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
+[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
 
 This candle implementation uses a pre-trained ConvNeXt network for inference. The
 classification head has been trained on the ImageNet dataset and returns the
candle-examples/examples/convnext/main.rs

@@ -12,38 +12,62 @@ use candle_transformers::models::convnext;
 
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Which {
+    Atto,
+    Femto,
+    Pico,
+    Nano,
     Tiny,
     Small,
     Base,
     Large,
+    AttoV2,
+    FemtoV2,
+    PicoV2,
+    NanoV2,
+    TinyV2,
+    BaseV2,
+    LargeV2,
     XLarge,
+    Huge,
 }
 
 impl Which {
     fn model_filename(&self) -> String {
         let name = match self {
-            Self::Tiny => "tiny",
-            Self::Small => "small",
-            Self::Base => "base",
-            Self::Large => "large",
-            Self::XLarge => "xlarge",
-        };
-        // The XLarge model only has an ImageNet-22K variant
-        let variant = match self {
-            Self::XLarge => "fb_in22k_ft_in1k",
-            _ => "fb_in1k",
+            Self::Atto => "convnext_atto.d2_in1k",
+            Self::Femto => "convnext_femto.d1_in1k",
+            Self::Pico => "convnext_pico.d1_in1k",
+            Self::Nano => "convnext_nano.d1h_in1k",
+            Self::Tiny => "convnext_tiny.fb_in1k",
+            Self::Small => "convnext_small.fb_in1k",
+            Self::Base => "convnext_base.fb_in1k",
+            Self::Large => "convnext_large.fb_in1k",
+            Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
+            Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
+            Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
+            Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
+            Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
+            Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
+            Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
+            Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
+            Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
         };
 
-        format!("timm/convnext_{name}.{variant}")
+        format!("timm/{name}")
     }
 
     fn config(&self) -> convnext::Config {
         match self {
-            Self::Tiny => convnext::Config::tiny(),
+            Self::Atto | Self::AttoV2 => convnext::Config::atto(),
+            Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
+            Self::Pico | Self::PicoV2 => convnext::Config::pico(),
+            Self::Nano | Self::NanoV2 => convnext::Config::nano(),
+            Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
             Self::Small => convnext::Config::small(),
-            Self::Base => convnext::Config::base(),
-            Self::Large => convnext::Config::large(),
+            Self::Base | Self::BaseV2 => convnext::Config::base(),
+            Self::Large | Self::LargeV2 => convnext::Config::large(),
             Self::XLarge => convnext::Config::xlarge(),
+            Self::Huge => convnext::Config::huge(),
         }
     }
 }
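For orientation, here is a minimal sketch of how the example would typically turn the timm repo name returned by `model_filename()` into a runnable model. The `load_model` helper and the anyhow error type are assumptions for illustration, not code from this commit; the final call assumes the `convnext::convnext(&Config, nclasses, vb)` entry point used by candle's vision examples.

    use candle::DType;
    use candle_nn::VarBuilder;

    // Hypothetical helper (not part of this diff): fetch the safetensors
    // weights for the selected variant from the Hugging Face Hub and build
    // the network with the matching config.
    fn load_model(which: Which) -> anyhow::Result<candle_nn::Func<'static>> {
        let api = hf_hub::api::sync::Api::new()?;
        let model_file = api.model(which.model_filename()).get("model.safetensors")?;
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &candle::Device::Cpu)?
        };
        // All the variants listed above ship a 1000-class ImageNet-1k head.
        Ok(convnext::convnext(&which.config(), 1000, vb)?)
    }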
candle-transformers/src/models/convnext.rs

@@ -2,10 +2,16 @@
 //!
 //! See "A ConvNet for the 2020s" Liu et al. 2022
 //! <https://arxiv.org/abs/2201.03545>
+//! and
+//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
+//! <https://arxiv.org/abs/2301.00808>
 
-//! Original code: https://github.com/facebookresearch/ConvNeXt/
+//! Original code:
+//! https://github.com/facebookresearch/ConvNeXt/
+//! https://github.com/facebookresearch/ConvNeXt-V2/
 //! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
 
+use candle::shape::ShapeWithOneHole;
 use candle::{Result, D};
 use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
 
@@ -13,31 +19,71 @@ use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
 pub struct Config {
     blocks: [usize; 4],
     channels: [usize; 4],
+    use_conv_mlp: bool,
 }
 
 impl Config {
+    pub fn atto() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [40, 80, 160, 320],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn femto() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [48, 96, 192, 384],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn pico() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [64, 128, 256, 512],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn nano() -> Self {
+        Self {
+            blocks: [2, 2, 8, 2],
+            channels: [80, 160, 320, 640],
+            use_conv_mlp: true,
+        }
+    }
+
     pub fn tiny() -> Self {
         Self {
             blocks: [3, 3, 9, 3],
             channels: [96, 192, 384, 768],
+            use_conv_mlp: false,
         }
     }
 
     pub fn small() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [96, 192, 384, 768],
+            use_conv_mlp: false,
         }
     }
 
     pub fn base() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [128, 256, 512, 1024],
+            use_conv_mlp: false,
         }
     }
 
     pub fn large() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [192, 384, 768, 1536],
+            use_conv_mlp: false,
         }
     }
 
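Each constructor is just a per-stage depth/width table: stage i stacks blocks[i] ConvNeXt blocks at width channels[i], with a 2x downsample between stages, and use_conv_mlp selects the channels-first MLP layout used by the smaller timm checkpoints. A standalone sketch of how these tables read (the describe helper is illustrative only, not part of this commit):

    // Illustrative only: print the per-stage layout encoded by a config.
    fn describe(name: &str, blocks: [usize; 4], channels: [usize; 4]) {
        for (i, (b, c)) in blocks.iter().zip(channels.iter()).enumerate() {
            println!("{name} stage {i}: {b} blocks at {c} channels");
        }
    }

    fn main() {
        // Values mirror Config::atto() and Config::tiny() above.
        describe("atto", [2, 2, 6, 2], [40, 80, 160, 320]);
        describe("tiny", [3, 3, 9, 3], [96, 192, 384, 768]);
    }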
@@ -45,8 +91,68 @@ impl Config {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [256, 512, 1024, 2048],
+            use_conv_mlp: false,
         }
     }
+
+    pub fn huge() -> Self {
+        Self {
+            blocks: [3, 3, 27, 3],
+            channels: [352, 704, 1408, 2816],
+            use_conv_mlp: false,
+        }
+    }
 }
 
+// Layer norm for data in channels-last format.
+fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let norm = layer_norm(dim, 1e-6, vb)?;
+
+    Ok(Func::new(move |xs| xs.apply(&norm)))
+}
+
+// Layer norm for data in channels-first format.
+fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let norm = layer_norm(dim, 1e-6, vb)?;
+
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&norm)?
+            .permute((0, 3, 1, 2))?;
+        Ok(xs)
+    }))
+}
+
+// Global response normalization layer
+// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
+fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
+    let (shape, spatial_dim, channel_dim) = if channels_last {
+        ((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
+    } else {
+        ((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
+    };
+
+    let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
+    let beta = vb.get(dim, "bias")?.reshape(&shape)?;
+
+    Ok(Func::new(move |xs| {
+        let residual = xs;
+        let gx = xs
+            .sqr()?
+            .sum_keepdim(spatial_dim)?
+            .mean_keepdim(spatial_dim)?
+            .sqrt()?;
+
+        let gxmean = gx.mean_keepdim(channel_dim)?;
+        let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
+        let xs = xs
+            .broadcast_mul(&nx)?
+            .broadcast_mul(&gamma)?
+            .broadcast_add(&beta)?;
+
+        xs + residual
+    }))
+}
+
 // Initial downsampling via a patchify layer.
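Read alongside convnext2_grn above: writing X_i for the i-th channel's H x W slice and eps = 1e-6, the layer computes the following (the sum_keepdim followed by mean_keepdim over the already-reduced spatial dims is a no-op, so gx is simply the spatial L2 norm):

    G(X)_i = \lVert X_i \rVert_2 = \sqrt{\textstyle\sum_{h,w} X_{i,h,w}^2}

    N(G(X))_i = \frac{G(X)_i}{\frac{1}{C} \sum_{j=1}^{C} G(X)_j + \epsilon}

    \mathrm{GRN}(X)_i = \gamma_i \, X_i \, N(G(X))_i + \beta_i + X_i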
@@ -56,16 +162,9 @@ fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
         ..Default::default()
     };
     let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
-    let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
-    Ok(Func::new(move |xs| {
-        // The layer norm works with channels-last format.
-        let xs = xs
-            .apply(&patchify)?
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .permute((0, 3, 1, 2))?;
-        Ok(xs)
-    }))
+    let norm = layer_norm_cf(out_channels, vb.pp(1))?;
+
+    Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
 }
 
 // Downsampling applied after the stages.
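The stem is a 4x4 convolution with stride 4, so a 224x224 input reaches the first stage at 56x56. A standalone sanity check of that shape arithmetic (illustrative: it builds a bare zero-weight conv rather than calling convnext_stem, which is private to this file):

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{Conv2dConfig, Module, VarBuilder};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Same geometry as the patchify layer: 3 -> C channels, 4x4 kernel, stride 4.
        let cfg = Conv2dConfig { stride: 4, ..Default::default() };
        let stem = candle_nn::conv2d(3, 96, 4, cfg, VarBuilder::zeros(DType::F32, &dev))?;
        let xs = Tensor::zeros((1, 3, 224, 224), DType::F32, &dev)?;
        println!("{:?}", stem.forward(&xs)?.dims()); // [1, 96, 56, 56]
        Ok(())
    }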
@@ -74,31 +173,49 @@ fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
         stride: 2,
         ..Default::default()
     };
-    let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
+    let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
     let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
-    Ok(Func::new(move |xs| {
-        let xs = xs
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .permute((0, 3, 1, 2))?
-            .apply(&conv)?;
-        Ok(xs)
-    }))
+
+    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
 }
 
-// MLP equivalent of pointwise convolutions.
+// MLP block from the original paper with optional GRN layer (v2 models).
 fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
     let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
     let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
+    let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
+
     Ok(Func::new(move |xs| {
-        let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
+        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
+        if let Ok(g) = &grn {
+            xs = xs.apply(g)?;
+        }
+        xs = xs.apply(&fc2)?;
         Ok(xs)
     }))
 }
 
-// A block consisting of a depthwise convolution, a MLP and layer scaling.
-fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
+fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        ..Default::default()
+    };
+    let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
+    let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
+
+    let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
+        if let Ok(g) = &grn {
+            xs = xs.apply(g)?;
+        }
+        xs = xs.apply(&fc2)?;
+        Ok(xs)
+    }))
+}
+
+// A block consisting of a depthwise convolution, a MLP and layer scaling (v1 models only).
+fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
     let conv2d_cfg = Conv2dConfig {
         groups: dim,
         padding: 3,
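The two MLP flavors are the same pointwise map in different layouts: a 1x1 convolution over (N, C, H, W) equals a linear layer applied per spatial position in channels-last order, which is why convnext_conv_mlp can drop the permutes that convnext_mlp needs. A small equivalence check sharing one weight tensor between both layouts (illustrative, not from this commit):

    use candle::{Device, Result, Tensor};
    use candle_nn::Module;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (c_in, c_out) = (4usize, 8usize);
        let w = Tensor::randn(0f32, 1f32, (c_out, c_in), &dev)?;
        let xs = Tensor::randn(0f32, 1f32, (1, c_in, 5, 5), &dev)?;

        // Channels-first: 1x1 conv with the weight viewed as (c_out, c_in, 1, 1).
        let conv = candle_nn::Conv2d::new(w.reshape((c_out, c_in, 1, 1))?, None, Default::default());
        let y_conv = conv.forward(&xs)?;

        // Channels-last: permute, plain linear, permute back.
        let lin = candle_nn::Linear::new(w, None);
        let y_lin = lin.forward(&xs.permute((0, 2, 3, 1))?)?.permute((0, 3, 1, 2))?;

        // Should be zero up to float rounding.
        println!("{:?}", (y_conv - y_lin)?.abs()?.flatten_all()?.max(0)?);
        Ok(())
    }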
@@ -106,20 +223,36 @@ fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
     };
 
     let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
+    let gamma = vb.get(dim, "gamma");
 
-    let gamma = vb.get(dim, "gamma")?;
-    let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
-    let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
+    let (mlp, norm) = if use_conv_mlp {
+        (
+            convnext_conv_mlp(dim, vb.pp("mlp"))?,
+            layer_norm_cf(dim, vb.pp("norm"))?,
+        )
+    } else {
+        (
+            convnext_mlp(dim, vb.pp("mlp"))?,
+            layer_norm_cl(dim, vb.pp("norm"))?,
+        )
+    };
 
     Ok(Func::new(move |xs| {
         let residual = xs;
-        let xs = xs
-            .apply(&conv_dw)?
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .apply(&mlp)?
-            .broadcast_mul(&gamma)?
-            .permute((0, 3, 1, 2))?;
+        let mut xs = xs.apply(&conv_dw)?;
+
+        xs = if use_conv_mlp {
+            xs.apply(&norm)?.apply(&mlp)?
+        } else {
+            xs.permute((0, 2, 3, 1))?
+                .apply(&norm)?
+                .apply(&mlp)?
+                .permute((0, 3, 1, 2))?
+        };
+
+        if let Ok(g) = &gamma {
+            xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
+        };
 
         xs + residual
     }))
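Note how vb.get(dim, "gamma") is left fallible rather than unwrapped with ?: V1 checkpoints carry a learned per-channel layer-scale gamma, while V2 checkpoints drop it in favor of GRN (as in the ConvNeXt V2 paper), so the same block serves both generations. Schematically:

    y = x + \gamma \odot \mathrm{MLP}(\mathrm{Norm}(\mathrm{DWConv}(x)))   \quad \text{(v1: gamma present)}

    y = x + \mathrm{MLP}_{\mathrm{GRN}}(\mathrm{Norm}(\mathrm{DWConv}(x))) \quad \text{(v2: no gamma, GRN inside the MLP)}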
@@ -137,7 +270,11 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
     }
 
     for block_idx in 0..nblocks {
-        blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
+        blocks.push(convnext_block(
+            dim,
+            cfg.use_conv_mlp,
+            vb.pp(format!("blocks.{block_idx}")),
+        )?);
     }
 
     Ok(Func::new(move |xs| {
@@ -149,8 +286,9 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
     }))
 }
 
+// Classification head.
 fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
-    let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
+    let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
     let linear = linear(outputs, nclasses, vb.pp("fc"))?;
     Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
 }
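Finally, a smoke-test sketch of the whole stack with zero weights. It assumes the file's public builder keeps candle's usual vision-model entry point, convnext::convnext(&Config, nclasses, vb), which is not shown in this diff:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::convnext;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Zero weights are enough to exercise the shapes end to end.
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let model = convnext::convnext(&convnext::Config::atto(), 1000, vb)?;
        let img = Tensor::zeros((1, 3, 224, 224), DType::F32, &dev)?;
        let logits = img.apply(&model)?;
        println!("{:?}", logits.dims()); // expect [1, 1000]
        Ok(())
    }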