From a52d407ae65cab1d78fc9f15b398c59cc96a2e82 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Sat, 3 Feb 2024 14:34:28 +0200
Subject: [PATCH] Add ConvNeXt model. (#1604)

---
 candle-examples/examples/convnext/README.md |  22 +++
 candle-examples/examples/convnext/main.rs   | 102 ++++++++++
 candle-transformers/src/models/convnext.rs  | 201 ++++++++++++++++++++
 candle-transformers/src/models/mod.rs       |   1 +
 4 files changed, 326 insertions(+)
 create mode 100644 candle-examples/examples/convnext/README.md
 create mode 100644 candle-examples/examples/convnext/main.rs
 create mode 100644 candle-transformers/src/models/convnext.rs

diff --git a/candle-examples/examples/convnext/README.md b/candle-examples/examples/convnext/README.md
new file mode 100644
index 00000000..03d4ef24
--- /dev/null
+++ b/candle-examples/examples/convnext/README.md
@@ -0,0 +1,22 @@
+# candle-convnext
+
+[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
+
+This candle implementation uses a pre-trained ConvNeXt network for inference. The
+classification head has been trained on the ImageNet dataset and returns the
+probabilities for the top-5 classes.
+
+## Running an example
+
+```
+$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
+
+loaded image Tensor[dims 3, 224, 224; f32]
+model built
+mountain bike, all-terrain bike, off-roader: 84.09%
+bicycle-built-for-two, tandem bicycle, tandem: 4.15%
+maillot                 : 0.74%
+crash helmet            : 0.54%
+unicycle, monocycle     : 0.44%
+
+```
diff --git a/candle-examples/examples/convnext/main.rs b/candle-examples/examples/convnext/main.rs
new file mode 100644
index 00000000..2ad3e84c
--- /dev/null
+++ b/candle-examples/examples/convnext/main.rs
@@ -0,0 +1,102 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::{Parser, ValueEnum};
+
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::convnext;
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+    Tiny,
+    Small,
+    Base,
+    Large,
+    XLarge,
+}
+
+impl Which {
+    fn model_filename(&self) -> String {
+        let name = match self {
+            Self::Tiny => "tiny",
+            Self::Small => "small",
+            Self::Base => "base",
+            Self::Large => "large",
+            Self::XLarge => "xlarge",
+        };
+        // The XLarge model only has an ImageNet-22K pretrained variant,
+        // so we use the version fine-tuned on ImageNet-1K.
+        let variant = match self {
+            Self::XLarge => "fb_in22k_ft_in1k",
+            _ => "fb_in1k",
+        };
+
+        format!("timm/convnext_{name}.{variant}")
+    }
+
+    fn config(&self) -> convnext::Config {
+        match self {
+            Self::Tiny => convnext::Config::tiny(),
+            Self::Small => convnext::Config::small(),
+            Self::Base => convnext::Config::base(),
+            Self::Large => convnext::Config::large(),
+            Self::XLarge => convnext::Config::xlarge(),
+        }
+    }
+}
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    #[arg(long)]
+    image: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(value_enum, long, default_value_t=Which::Tiny)]
+    which: Which,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let image = candle_examples::imagenet::load_image224(args.image)?;
+    println!("loaded image {image:?}");
+
+    let model_file = match args.model {
+        None => {
+            let model_name = args.which.model_filename();
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model(model_name);
+            api.get("model.safetensors")?
+        }
+        Some(model) => model.into(),
+    };
+
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+    let model = convnext::convnext(&args.which.config(), 1000, vb)?;
+    println!("model built");
+    let logits = model.forward(&image.unsqueeze(0)?)?;
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+        .i(0)?
+        .to_vec1::<f32>()?;
+    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+    for &(category_idx, pr) in prs.iter().take(5) {
+        println!(
+            "{:24}: {:.2}%",
+            candle_examples::imagenet::CLASSES[category_idx],
+            100. * pr
+        );
+    }
+    Ok(())
+}
diff --git a/candle-transformers/src/models/convnext.rs b/candle-transformers/src/models/convnext.rs
new file mode 100644
index 00000000..56bd045c
--- /dev/null
+++ b/candle-transformers/src/models/convnext.rs
@@ -0,0 +1,201 @@
+//! ConvNeXt implementation.
+//!
+//! See "A ConvNet for the 2020s" Liu et al. 2022
+//! <https://arxiv.org/abs/2201.03545>
+
+//! Original code: https://github.com/facebookresearch/ConvNeXt/
+//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
+
+use candle::{Result, D};
+use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
+
+#[derive(Clone)]
+pub struct Config {
+    blocks: [usize; 4],
+    channels: [usize; 4],
+}
+
+impl Config {
+    pub fn tiny() -> Self {
+        Self {
+            blocks: [3, 3, 9, 3],
+            channels: [96, 192, 384, 768],
+        }
+    }
+    pub fn small() -> Self {
+        Self {
+            blocks: [3, 3, 27, 3],
+            channels: [96, 192, 384, 768],
+        }
+    }
+    pub fn base() -> Self {
+        Self {
+            blocks: [3, 3, 27, 3],
+            channels: [128, 256, 512, 1024],
+        }
+    }
+    pub fn large() -> Self {
+        Self {
+            blocks: [3, 3, 27, 3],
+            channels: [192, 384, 768, 1536],
+        }
+    }
+
+    pub fn xlarge() -> Self {
+        Self {
+            blocks: [3, 3, 27, 3],
+            channels: [256, 512, 1024, 2048],
+        }
+    }
+}
+
+// Initial downsampling via a patchify layer.
+fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        stride: 4,
+        ..Default::default()
+    };
+    let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
+    let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
+    Ok(Func::new(move |xs| {
+        // The layer norm works with channels-last format.
+        let xs = xs
+            .apply(&patchify)?
+            .permute((0, 2, 3, 1))?
+            .apply(&norm)?
+            .permute((0, 3, 1, 2))?;
+        Ok(xs)
+    }))
+}
+
+// Downsampling applied at the start of every stage but the first.
+fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        stride: 2,
+        ..Default::default()
+    };
+    let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
+    let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&norm)?
+            .permute((0, 3, 1, 2))?
+            .apply(&conv)?;
+        Ok(xs)
+    }))
+}
+
+// MLP equivalent of pointwise convolutions.
+fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
+    let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
+
+    Ok(Func::new(move |xs| {
+        let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
+        Ok(xs)
+    }))
+}
+
+// A block consisting of a depthwise convolution, an MLP and layer scaling.
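+// The depthwise convolution mixes information spatially within each channel,
+// the pointwise MLP mixes it across channels, and the learned per-channel
+// `gamma` scales the residual branch (layer scale).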
+fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        groups: dim,
+        padding: 3,
+        ..Default::default()
+    };
+
+    let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
+
+    let gamma = vb.get(dim, "gamma")?;
+    let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
+    let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
+
+    Ok(Func::new(move |xs| {
+        let residual = xs;
+        let xs = xs
+            .apply(&conv_dw)?
+            .permute((0, 2, 3, 1))?
+            .apply(&norm)?
+            .apply(&mlp)?
+            .broadcast_mul(&gamma)?
+            .permute((0, 3, 1, 2))?;
+
+        xs + residual
+    }))
+}
+
+// Each stage contains its blocks, preceded by a downsampling layer for every
+// stage but the first.
+fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let nblocks = cfg.blocks[stage_idx];
+    let mut blocks = Vec::with_capacity(nblocks);
+
+    let dim = cfg.channels[stage_idx];
+
+    if stage_idx > 0 {
+        blocks.push(convnext_downsample(dim, vb.pp("downsample"))?);
+    }
+
+    for block_idx in 0..nblocks {
+        blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
+    }
+
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.clone();
+        for block in blocks.iter() {
+            xs = xs.apply(block)?
+        }
+        Ok(xs)
+    }))
+}
+
+fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
+    let linear = linear(outputs, nclasses, vb.pp("fc"))?;
+    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
+}
+
+// Build a convnext model for a given configuration.
+fn convnext_model(
+    config: &Config,
+    nclasses: Option<usize>,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let head = match nclasses {
+        None => None,
+        Some(nclasses) => {
+            let head = convnext_head(config.channels[3], nclasses, vb.pp("head"))?;
+            Some(head)
+        }
+    };
+
+    let stem = convnext_stem(config.channels[0], vb.pp("stem"))?;
+    let vb = vb.pp("stages");
+    let stage1 = convnext_stage(config, 0, vb.pp(0))?;
+    let stage2 = convnext_stage(config, 1, vb.pp(1))?;
+    let stage3 = convnext_stage(config, 2, vb.pp(2))?;
+    let stage4 = convnext_stage(config, 3, vb.pp(3))?;
+
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .apply(&stem)?
+            .apply(&stage1)?
+            .apply(&stage2)?
+            .apply(&stage3)?
+            .apply(&stage4)?
+            .mean(D::Minus2)?
+            .mean(D::Minus1)?;
+        match &head {
+            None => Ok(xs),
+            Some(head) => xs.apply(head),
+        }
+    }))
+}
+
+pub fn convnext(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    convnext_model(cfg, Some(nclasses), vb)
+}
+
+pub fn convnext_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
+    convnext_model(cfg, None, vb)
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a94fd07a..e92e02e8 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -3,6 +3,7 @@ pub mod bigcode;
 pub mod blip;
 pub mod blip_text;
 pub mod convmixer;
+pub mod convnext;
 pub mod dinov2;
 pub mod distilbert;
 pub mod efficientnet;
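For reviewers who want to try the new API outside the example binary, here is a minimal sketch (not part of the patch) of using the headless constructor as a feature extractor. The local `model.safetensors` path and the zero-filled input are placeholders for illustration:

```
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convnext;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Placeholder path: the example above fetches the weights from the
    // Hugging Face hub instead.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    // The headless variant skips the classification head and returns the
    // globally pooled features: a [batch, 768] tensor for the tiny config.
    let model = convnext::convnext_no_final_layer(&convnext::Config::tiny(), vb)?;
    let input = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    let features = model.forward(&input)?;
    println!("features: {features:?}");
    Ok(())
}
```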