Add ConvNeXt-V2 and smaller model variants. (#1709)

This commit is contained in:
Jani Monoses
2024-02-14 11:53:07 +02:00
committed by GitHub
parent b60064780d
commit 68f7655895
3 changed files with 214 additions and 51 deletions

View File

@ -1,6 +1,7 @@
# candle-convnext
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
This candle implementation uses a pre-trained ConvNeXt network for inference. The
classification head has been trained on the ImageNet dataset and returns the

View File

@ -12,38 +12,62 @@ use candle_transformers::models::convnext;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Atto,
Femto,
Pico,
Nano,
Tiny,
Small,
Base,
Large,
AttoV2,
FemtoV2,
PicoV2,
NanoV2,
TinyV2,
BaseV2,
LargeV2,
XLarge,
Huge,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::Tiny => "tiny",
Self::Small => "small",
Self::Base => "base",
Self::Large => "large",
Self::XLarge => "xlarge",
};
// The XLarge model only has an ImageNet-22K variant
let variant = match self {
Self::XLarge => "fb_in22k_ft_in1k",
_ => "fb_in1k",
Self::Atto => "convnext_atto.d2_in1k",
Self::Femto => "convnext_femto.d1_in1k",
Self::Pico => "convnext_pico.d1_in1k",
Self::Nano => "convnext_nano.d1h_in1k",
Self::Tiny => "convnext_tiny.fb_in1k",
Self::Small => "convnext_small.fb_in1k",
Self::Base => "convnext_base.fb_in1k",
Self::Large => "convnext_large.fb_in1k",
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
};
format!("timm/convnext_{name}.{variant}")
format!("timm/{name}")
}
fn config(&self) -> convnext::Config {
match self {
Self::Tiny => convnext::Config::tiny(),
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
Self::Small => convnext::Config::small(),
Self::Base => convnext::Config::base(),
Self::Large => convnext::Config::large(),
Self::Base | Self::BaseV2 => convnext::Config::base(),
Self::Large | Self::LargeV2 => convnext::Config::large(),
Self::XLarge => convnext::Config::xlarge(),
Self::Huge => convnext::Config::huge(),
}
}
}