From 68f76558956f7f56cb5014bb5f7c7c5534436b72 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Wed, 14 Feb 2024 11:53:07 +0200
Subject: [PATCH] Add ConvNeXt-V2 and smaller model variants. (#1709)

---
 candle-examples/examples/convnext/README.md |   3 +-
 candle-examples/examples/convnext/main.rs   |  52 +++--
 candle-transformers/src/models/convnext.rs  | 210 ++++++++++++++++----
 3 files changed, 214 insertions(+), 51 deletions(-)

diff --git a/candle-examples/examples/convnext/README.md b/candle-examples/examples/convnext/README.md
index 03d4ef24..d532d7a4 100644
--- a/candle-examples/examples/convnext/README.md
+++ b/candle-examples/examples/convnext/README.md
@@ -1,6 +1,7 @@
 # candle-convnext
 
-[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
+[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
+[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
 
 This candle implementation uses a pre-trained ConvNeXt network for inference. The
 classification head has been trained on the ImageNet dataset and returns the
diff --git a/candle-examples/examples/convnext/main.rs b/candle-examples/examples/convnext/main.rs
index 2ad3e84c..8fc72e16 100644
--- a/candle-examples/examples/convnext/main.rs
+++ b/candle-examples/examples/convnext/main.rs
@@ -12,38 +12,62 @@ use candle_transformers::models::convnext;
 
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Which {
+    Atto,
+    Femto,
+    Pico,
+    Nano,
     Tiny,
     Small,
     Base,
     Large,
+    AttoV2,
+    FemtoV2,
+    PicoV2,
+    NanoV2,
+    TinyV2,
+    BaseV2,
+    LargeV2,
     XLarge,
+    Huge,
 }
 
 impl Which {
     fn model_filename(&self) -> String {
         let name = match self {
-            Self::Tiny => "tiny",
-            Self::Small => "small",
-            Self::Base => "base",
-            Self::Large => "large",
-            Self::XLarge => "xlarge",
-        };
-        // The XLarge model only has an ImageNet-22K variant
-        let variant = match self {
-            Self::XLarge => "fb_in22k_ft_in1k",
-            _ => "fb_in1k",
+            Self::Atto => "convnext_atto.d2_in1k",
+            Self::Femto => "convnext_femto.d1_in1k",
+            Self::Pico => "convnext_pico.d1_in1k",
+            Self::Nano => "convnext_nano.d1h_in1k",
+            Self::Tiny => "convnext_tiny.fb_in1k",
+            Self::Small => "convnext_small.fb_in1k",
+            Self::Base => "convnext_base.fb_in1k",
+            Self::Large => "convnext_large.fb_in1k",
+            Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
+            Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
+            Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
+            Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
+            Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
+            Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
+            Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
+            Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
+            Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
         };
-        format!("timm/convnext_{name}.{variant}")
+        format!("timm/{name}")
     }
 
     fn config(&self) -> convnext::Config {
        match self {
-            Self::Tiny => convnext::Config::tiny(),
+            Self::Atto | Self::AttoV2 => convnext::Config::atto(),
+            Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
+            Self::Pico | Self::PicoV2 => convnext::Config::pico(),
+            Self::Nano | Self::NanoV2 => convnext::Config::nano(),
+            Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
             Self::Small => convnext::Config::small(),
-            Self::Base => convnext::Config::base(),
-            Self::Large => convnext::Config::large(),
+            Self::Base | Self::BaseV2 => convnext::Config::base(),
+            Self::Large | Self::LargeV2 => convnext::Config::large(),
             Self::XLarge => convnext::Config::xlarge(),
+            Self::Huge => convnext::Config::huge(),
         }
     }
 }
diff --git a/candle-transformers/src/models/convnext.rs b/candle-transformers/src/models/convnext.rs
index 56bd045c..94b1833e 100644
--- a/candle-transformers/src/models/convnext.rs
+++ b/candle-transformers/src/models/convnext.rs
@@ -2,10 +2,16 @@
 //!
 //! See "A ConvNet for the 2020s" Liu et al. 2022
 //!
+//! and
+//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
+//!
-//! Original code: https://github.com/facebookresearch/ConvNeXt/
+//! Original code:
+//! https://github.com/facebookresearch/ConvNeXt/
+//! https://github.com/facebookresearch/ConvNeXt-V2/
 //! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
 
+use candle::shape::ShapeWithOneHole;
 use candle::{Result, D};
 use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
@@ -13,31 +19,71 @@ use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
 pub struct Config {
     blocks: [usize; 4],
     channels: [usize; 4],
+    use_conv_mlp: bool,
 }
 
 impl Config {
+    pub fn atto() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [40, 80, 160, 320],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn femto() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [48, 96, 192, 384],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn pico() -> Self {
+        Self {
+            blocks: [2, 2, 6, 2],
+            channels: [64, 128, 256, 512],
+            use_conv_mlp: true,
+        }
+    }
+
+    pub fn nano() -> Self {
+        Self {
+            blocks: [2, 2, 8, 2],
+            channels: [80, 160, 320, 640],
+            use_conv_mlp: true,
+        }
+    }
+
     pub fn tiny() -> Self {
         Self {
             blocks: [3, 3, 9, 3],
             channels: [96, 192, 384, 768],
+            use_conv_mlp: false,
         }
     }
+
     pub fn small() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [96, 192, 384, 768],
+            use_conv_mlp: false,
         }
     }
+
     pub fn base() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [128, 256, 512, 1024],
+            use_conv_mlp: false,
         }
     }
+
     pub fn large() -> Self {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [192, 384, 768, 1536],
+            use_conv_mlp: false,
         }
     }
@@ -45,8 +91,68 @@ impl Config {
         Self {
             blocks: [3, 3, 27, 3],
             channels: [256, 512, 1024, 2048],
+            use_conv_mlp: false,
         }
     }
+
+    pub fn huge() -> Self {
+        Self {
+            blocks: [3, 3, 27, 3],
+            channels: [352, 704, 1408, 2816],
+            use_conv_mlp: false,
+        }
+    }
+}
+
+// Layer norm for data in channels-last format.
+fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let norm = layer_norm(dim, 1e-6, vb)?;
+
+    Ok(Func::new(move |xs| xs.apply(&norm)))
+}
+
+// Layer norm for data in channels-first format.
+fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let norm = layer_norm(dim, 1e-6, vb)?;
+
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&norm)?
+            .permute((0, 3, 1, 2))?;
+        Ok(xs)
+    }))
+}
+
+// Global response normalization layer
+// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
+fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
+    let (shape, spatial_dim, channel_dim) = if channels_last {
+        ((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
+    } else {
+        ((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
+    };
+
+    let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
+    let beta = vb.get(dim, "bias")?.reshape(&shape)?;
+
+    Ok(Func::new(move |xs| {
+        let residual = xs;
+        let gx = xs
+            .sqr()?
+            .sum_keepdim(spatial_dim)?
+            .mean_keepdim(spatial_dim)?
+            .sqrt()?;
+
+        let gxmean = gx.mean_keepdim(channel_dim)?;
+        let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
+        let xs = xs
+            .broadcast_mul(&nx)?
+            .broadcast_mul(&gamma)?
+            .broadcast_add(&beta)?;
+
+        xs + residual
+    }))
 }
 
 // Initial downsampling via a patchify layer.
@@ -56,16 +162,9 @@ fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
         ..Default::default()
     };
     let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
-    let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
-    Ok(Func::new(move |xs| {
-        // The layer norm works with channels-last format.
-        let xs = xs
-            .apply(&patchify)?
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .permute((0, 3, 1, 2))?;
-        Ok(xs)
-    }))
+    let norm = layer_norm_cf(out_channels, vb.pp(1))?;
+
+    Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
 }
 
 // Downsampling applied after the stages.
@@ -74,31 +173,49 @@ fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
         stride: 2,
         ..Default::default()
     };
-    let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
+    let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
     let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
-    Ok(Func::new(move |xs| {
-        let xs = xs
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .permute((0, 3, 1, 2))?
-            .apply(&conv)?;
-        Ok(xs)
-    }))
+
+    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
 }
 
-// MLP equivalent of pointwise convolutions.
+// MLP block from the original paper with optional GRN layer (v2 models).
 fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
     let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
     let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
+    let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
 
     Ok(Func::new(move |xs| {
-        let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
+        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
+        if let Ok(g) = &grn {
+            xs = xs.apply(g)?;
+        }
+        xs = xs.apply(&fc2)?;
         Ok(xs)
     }))
 }
 
-// A block consisting of a depthwise convolution, a MLP and layer scaling.
-fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
+fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        ..Default::default()
+    };
+    let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
+    let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
+
+    let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
+        if let Ok(g) = &grn {
+            xs = xs.apply(g)?;
+        }
+        xs = xs.apply(&fc2)?;
+        Ok(xs)
+    }))
+}
+
+// A block consisting of a depthwise convolution, a MLP and layer scaling (v1 models only).
+fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
     let conv2d_cfg = Conv2dConfig {
         groups: dim,
         padding: 3,
@@ -106,20 +223,36 @@ fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
     };
 
     let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
+    let gamma = vb.get(dim, "gamma");
 
-    let gamma = vb.get(dim, "gamma")?;
-    let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
-    let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
+    let (mlp, norm) = if use_conv_mlp {
+        (
+            convnext_conv_mlp(dim, vb.pp("mlp"))?,
+            layer_norm_cf(dim, vb.pp("norm"))?,
+        )
+    } else {
+        (
+            convnext_mlp(dim, vb.pp("mlp"))?,
+            layer_norm_cl(dim, vb.pp("norm"))?,
+        )
+    };
 
     Ok(Func::new(move |xs| {
         let residual = xs;
-        let xs = xs
-            .apply(&conv_dw)?
-            .permute((0, 2, 3, 1))?
-            .apply(&norm)?
-            .apply(&mlp)?
-            .broadcast_mul(&gamma)?
-            .permute((0, 3, 1, 2))?;
+        let mut xs = xs.apply(&conv_dw)?;
+
+        xs = if use_conv_mlp {
+            xs.apply(&norm)?.apply(&mlp)?
+        } else {
+            xs.permute((0, 2, 3, 1))?
+                .apply(&norm)?
+                .apply(&mlp)?
+                .permute((0, 3, 1, 2))?
+        };
+
+        if let Ok(g) = &gamma {
+            xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
+        };
 
         xs + residual
     }))
@@ -137,7 +270,11 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
     }
 
     for block_idx in 0..nblocks {
-        blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
+        blocks.push(convnext_block(
+            dim,
+            cfg.use_conv_mlp,
+            vb.pp(format!("blocks.{block_idx}")),
+        )?);
     }
 
     Ok(Func::new(move |xs| {
 fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
-    let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
+    let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
     let linear = linear(outputs, nclasses, vb.pp("fc"))?;
     Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
 }
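
For context, a minimal sketch of how the new variants can be exercised once this patch is applied. It mirrors what the updated example does, but the `convnext::convnext` builder call, the local `model.safetensors` path, and the 1000-class ImageNet head are assumptions made for illustration rather than parts of the diff above.

```rust
// Sketch only: assumes a local safetensors copy of the
// timm/convnextv2_nano.fcmae_ft_in1k weights and the public
// convnext::convnext(cfg, nclasses, vb) builder from candle-transformers.
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::convnext;

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    // V1 and V2 variants of the same size share a Config; Config::nano() enables
    // the conv-MLP path and the GRN weights are picked up when present.
    let model = convnext::convnext(&convnext::Config::nano(), 1000, vb)?;
    // Dummy 224x224 input; the real example feeds a normalized image tensor.
    let input = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    let logits = input.apply(&model)?;
    println!("logits shape: {:?}", logits.shape());
    Ok(())
}
```

The same call with `Config::atto()` through `Config::huge()` covers the other variants added in this patch.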