mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add ConvNeXt model. (#1604)
This commit is contained in:
22
candle-examples/examples/convnext/README.md
Normal file
22
candle-examples/examples/convnext/README.md
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# candle-convnext
|
||||||
|
|
||||||
|
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
|
||||||
|
|
||||||
|
This candle implementation uses a pre-trained ConvNeXt network for inference. The
|
||||||
|
classification head has been trained on the ImageNet dataset and returns the
|
||||||
|
probabilities for the top-5 classes.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
|
||||||
|
|
||||||
|
loaded image Tensor[dims 3, 224, 224; f32]
|
||||||
|
model built
|
||||||
|
mountain bike, all-terrain bike, off-roader: 84.09%
|
||||||
|
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
|
||||||
|
maillot : 0.74%
|
||||||
|
crash helmet : 0.54%
|
||||||
|
unicycle, monocycle : 0.44%
|
||||||
|
|
||||||
|
```
|
102
candle-examples/examples/convnext/main.rs
Normal file
102
candle-examples/examples/convnext/main.rs
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
|
use candle::{DType, IndexOp, D};
|
||||||
|
use candle_nn::{Module, VarBuilder};
|
||||||
|
use candle_transformers::models::convnext;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
Tiny,
|
||||||
|
Small,
|
||||||
|
Base,
|
||||||
|
Large,
|
||||||
|
XLarge,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Which {
|
||||||
|
fn model_filename(&self) -> String {
|
||||||
|
let name = match self {
|
||||||
|
Self::Tiny => "tiny",
|
||||||
|
Self::Small => "small",
|
||||||
|
Self::Base => "base",
|
||||||
|
Self::Large => "large",
|
||||||
|
Self::XLarge => "xlarge",
|
||||||
|
};
|
||||||
|
// The XLarge model only has an ImageNet-22K variant
|
||||||
|
let variant = match self {
|
||||||
|
Self::XLarge => "fb_in22k_ft_in1k",
|
||||||
|
_ => "fb_in1k",
|
||||||
|
};
|
||||||
|
|
||||||
|
format!("timm/convnext_{name}.{variant}")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config(&self) -> convnext::Config {
|
||||||
|
match self {
|
||||||
|
Self::Tiny => convnext::Config::tiny(),
|
||||||
|
Self::Small => convnext::Config::small(),
|
||||||
|
Self::Base => convnext::Config::base(),
|
||||||
|
Self::Large => convnext::Config::large(),
|
||||||
|
Self::XLarge => convnext::Config::xlarge(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
image: String,
|
||||||
|
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
#[arg(value_enum, long, default_value_t=Which::Tiny)]
|
||||||
|
which: Which,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||||
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
|
let model_file = match args.model {
|
||||||
|
None => {
|
||||||
|
let model_name = args.which.model_filename();
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model(model_name);
|
||||||
|
api.get("model.safetensors")?
|
||||||
|
}
|
||||||
|
Some(model) => model.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||||
|
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
|
||||||
|
println!("model built");
|
||||||
|
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||||
|
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||||
|
.i(0)?
|
||||||
|
.to_vec1::<f32>()?;
|
||||||
|
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||||
|
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||||
|
for &(category_idx, pr) in prs.iter().take(5) {
|
||||||
|
println!(
|
||||||
|
"{:24}: {:.2}%",
|
||||||
|
candle_examples::imagenet::CLASSES[category_idx],
|
||||||
|
100. * pr
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
201
candle-transformers/src/models/convnext.rs
Normal file
201
candle-transformers/src/models/convnext.rs
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
//! ConvNeXt implementation.
|
||||||
|
//!
|
||||||
|
//! See "A ConvNet for the 2020s" Liu et al. 2022
|
||||||
|
//! <https://arxiv.org/abs/2201.03545>
|
||||||
|
|
||||||
|
//! Original code: <https://github.com/facebookresearch/ConvNeXt/>
|
||||||
|
//! timm: <https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py>
|
||||||
|
|
||||||
|
use candle::{Result, D};
|
||||||
|
use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
|
||||||
|
|
||||||
|
/// Architecture hyper-parameters of a ConvNeXt network.
///
/// The network has four stages; entry `i` of each array describes stage `i`.
#[derive(Clone, Debug)]
pub struct Config {
    // Number of ConvNeXt blocks in each stage.
    blocks: [usize; 4],
    // Number of channels used by the blocks of each stage.
    channels: [usize; 4],
}

impl Config {
    /// Preset for the "tiny" model size.
    pub fn tiny() -> Self {
        Self {
            blocks: [3, 3, 9, 3],
            channels: [96, 192, 384, 768],
        }
    }

    /// Preset for the "small" model size.
    pub fn small() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [96, 192, 384, 768],
        }
    }

    /// Preset for the "base" model size.
    pub fn base() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [128, 256, 512, 1024],
        }
    }

    /// Preset for the "large" model size.
    pub fn large() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [192, 384, 768, 1536],
        }
    }

    /// Preset for the "xlarge" model size.
    pub fn xlarge() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [256, 512, 1024, 2048],
        }
    }
}
|
||||||
|
|
||||||
|
// Initial downsampling via a patchify layer.
|
||||||
|
fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let conv2d_cfg = Conv2dConfig {
|
||||||
|
stride: 4,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
|
||||||
|
let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
|
||||||
|
Ok(Func::new(move |xs| {
|
||||||
|
// The layer norm works with channels-last format.
|
||||||
|
let xs = xs
|
||||||
|
.apply(&patchify)?
|
||||||
|
.permute((0, 2, 3, 1))?
|
||||||
|
.apply(&norm)?
|
||||||
|
.permute((0, 3, 1, 2))?;
|
||||||
|
Ok(xs)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Downsampling applied after the stages.
|
||||||
|
fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let conv2d_cfg = Conv2dConfig {
|
||||||
|
stride: 2,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
|
||||||
|
let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
|
||||||
|
Ok(Func::new(move |xs| {
|
||||||
|
let xs = xs
|
||||||
|
.permute((0, 2, 3, 1))?
|
||||||
|
.apply(&norm)?
|
||||||
|
.permute((0, 3, 1, 2))?
|
||||||
|
.apply(&conv)?;
|
||||||
|
Ok(xs)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MLP equivalent of pointwise convolutions.
|
||||||
|
fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
|
||||||
|
let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
|
||||||
|
|
||||||
|
Ok(Func::new(move |xs| {
|
||||||
|
let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
|
||||||
|
Ok(xs)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// A block consisting of a depthwise convolution, a MLP and layer scaling.
|
||||||
|
fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let conv2d_cfg = Conv2dConfig {
|
||||||
|
groups: dim,
|
||||||
|
padding: 3,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
|
||||||
|
|
||||||
|
let gamma = vb.get(dim, "gamma")?;
|
||||||
|
let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
|
||||||
|
let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
|
||||||
|
|
||||||
|
Ok(Func::new(move |xs| {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = xs
|
||||||
|
.apply(&conv_dw)?
|
||||||
|
.permute((0, 2, 3, 1))?
|
||||||
|
.apply(&norm)?
|
||||||
|
.apply(&mlp)?
|
||||||
|
.broadcast_mul(&gamma)?
|
||||||
|
.permute((0, 3, 1, 2))?;
|
||||||
|
|
||||||
|
xs + residual
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each stage contains blocks and a downsampling layer for the previous stage.
|
||||||
|
fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let nblocks = cfg.blocks[stage_idx];
|
||||||
|
let mut blocks = Vec::with_capacity(nblocks);
|
||||||
|
|
||||||
|
let dim = cfg.channels[stage_idx];
|
||||||
|
|
||||||
|
if stage_idx > 0 {
|
||||||
|
blocks.push(convnext_downsample(dim, vb.pp("downsample"))?);
|
||||||
|
}
|
||||||
|
|
||||||
|
for block_idx in 0..nblocks {
|
||||||
|
blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Func::new(move |xs| {
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for block in blocks.iter() {
|
||||||
|
xs = xs.apply(block)?
|
||||||
|
}
|
||||||
|
Ok(xs)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
|
||||||
|
let linear = linear(outputs, nclasses, vb.pp("fc"))?;
|
||||||
|
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a convnext model for a given configuration.
|
||||||
|
fn convnext_model(
|
||||||
|
config: &Config,
|
||||||
|
nclasses: Option<usize>,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Func<'static>> {
|
||||||
|
let head = match nclasses {
|
||||||
|
None => None,
|
||||||
|
Some(nclasses) => {
|
||||||
|
let head = convnext_head(config.channels[3], nclasses, vb.pp("head"))?;
|
||||||
|
Some(head)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let stem = convnext_stem(config.channels[0], vb.pp("stem"))?;
|
||||||
|
let vb = vb.pp("stages");
|
||||||
|
let stage1 = convnext_stage(config, 0, vb.pp(0))?;
|
||||||
|
let stage2 = convnext_stage(config, 1, vb.pp(1))?;
|
||||||
|
let stage3 = convnext_stage(config, 2, vb.pp(2))?;
|
||||||
|
let stage4 = convnext_stage(config, 3, vb.pp(3))?;
|
||||||
|
|
||||||
|
Ok(Func::new(move |xs| {
|
||||||
|
let xs = xs
|
||||||
|
.apply(&stem)?
|
||||||
|
.apply(&stage1)?
|
||||||
|
.apply(&stage2)?
|
||||||
|
.apply(&stage3)?
|
||||||
|
.apply(&stage4)?
|
||||||
|
.mean(D::Minus2)?
|
||||||
|
.mean(D::Minus1)?;
|
||||||
|
match &head {
|
||||||
|
None => Ok(xs),
|
||||||
|
Some(head) => xs.apply(head),
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn convnext(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
convnext_model(cfg, Some(nclasses), vb)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn convnext_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
convnext_model(cfg, None, vb)
|
||||||
|
}
|
@ -3,6 +3,7 @@ pub mod bigcode;
|
|||||||
pub mod blip;
|
pub mod blip;
|
||||||
pub mod blip_text;
|
pub mod blip_text;
|
||||||
pub mod convmixer;
|
pub mod convmixer;
|
||||||
|
pub mod convnext;
|
||||||
pub mod dinov2;
|
pub mod dinov2;
|
||||||
pub mod distilbert;
|
pub mod distilbert;
|
||||||
pub mod efficientnet;
|
pub mod efficientnet;
|
||||||
|
Reference in New Issue
Block a user