//! EfficientNet implementation.
//!
//! https://arxiv.org/abs/1905.11946
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};

// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
    expand_ratio: f64,
    kernel: usize,
    stride: usize,
    input_channels: usize,
    out_channels: usize,
    num_layers: usize,
}

/// Rounds `v` to the nearest multiple of `divisor` (at least `divisor`), bumping the result up
/// if it would fall below 90% of `v`.
fn make_divisible(v: f64, divisor: usize) -> usize {
    let min_value = divisor;
    let new_v = usize::max(
        min_value,
        (v + divisor as f64 * 0.5) as usize / divisor * divisor,
    );
    if (new_v as f64) < 0.9 * v {
        new_v + divisor
    } else {
        new_v
    }
}

/// Scales the base EfficientNet-B0 stage configuration by the given width and depth multipliers.
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
    let bneck_conf = |e, k, s, i, o, n| {
        let input_channels = make_divisible(i as f64 * width_mult, 8);
        let out_channels = make_divisible(o as f64 * width_mult, 8);
        let num_layers = (n as f64 * depth_mult).ceil() as usize;
        MBConvConfig {
            expand_ratio: e,
            kernel: k,
            stride: s,
            input_channels,
            out_channels,
            num_layers,
        }
    };
    vec![
        bneck_conf(1., 3, 1, 32, 16, 1),
        bneck_conf(6., 3, 2, 16, 24, 2),
        bneck_conf(6., 5, 2, 24, 40, 2),
        bneck_conf(6., 3, 2, 40, 80, 3),
        bneck_conf(6., 5, 1, 80, 112, 3),
        bneck_conf(6., 5, 2, 112, 192, 4),
        bneck_conf(6., 3, 1, 192, 320, 1),
    ]
}

impl MBConvConfig {
    fn b0() -> Vec<Self> {
        bneck_confs(1.0, 1.0)
    }
    fn b1() -> Vec<Self> {
        bneck_confs(1.0, 1.1)
    }
    fn b2() -> Vec<Self> {
        bneck_confs(1.1, 1.2)
    }
    fn b3() -> Vec<Self> {
        bneck_confs(1.2, 1.4)
    }
    fn b4() -> Vec<Self> {
        bneck_confs(1.4, 1.8)
    }
    fn b5() -> Vec<Self> {
        bneck_confs(1.6, 2.2)
    }
    fn b6() -> Vec<Self> {
        bneck_confs(1.8, 2.6)
    }
    fn b7() -> Vec<Self> {
        bneck_confs(2.0, 3.1)
    }
}

/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
    conv2d: nn::Conv2d,
    s: usize,
    k: usize,
}

impl Conv2DSame {
    fn new(
        vb: VarBuilder,
        i: usize,
        o: usize,
        k: usize,
        stride: usize,
        groups: usize,
        bias: bool,
    ) -> Result<Self> {
        let conv_config = nn::Conv2dConfig {
            stride,
            groups,
            ..Default::default()
        };
        let conv2d = if bias {
            nn::conv2d(i, o, k, conv_config, vb)?
        } else {
            nn::conv2d_no_bias(i, o, k, conv_config, vb)?
        };
        Ok(Self {
            conv2d,
            s: stride,
            k,
        })
    }
}

impl Module for Conv2DSame {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let s = self.s;
        let k = self.k;
        let (_, _, ih, iw) = xs.dims4()?;
        // TensorFlow-style "same" padding: pad so that the output spatial dims
        // are ceil(input / stride).
        let oh = (ih + s - 1) / s;
        let ow = (iw + s - 1) / s;
        let pad_h = usize::max((oh - 1) * s + k - ih, 0);
        let pad_w = usize::max((ow - 1) * s + k - iw, 0);
        if pad_h > 0 || pad_w > 0 {
            let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
            let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
            self.conv2d.forward(&xs)
        } else {
            self.conv2d.forward(xs)
        }
    }
}

#[derive(Debug)]
struct ConvNormActivation {
    conv2d: Conv2DSame,
    bn2d: nn::BatchNorm,
    activation: bool,
}

impl ConvNormActivation {
    fn new(
        vb: VarBuilder,
        i: usize,
        o: usize,
        k: usize,
        stride: usize,
        groups: usize,
    ) -> Result<Self> {
        let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
        let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
        Ok(Self {
            conv2d,
            bn2d,
            activation: true,
        })
    }

    fn no_activation(self) -> Self {
        Self {
            activation: false,
            ..self
        }
    }
}

impl Module for ConvNormActivation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.conv2d.forward(xs)?;
        let xs = self.bn2d.forward(&xs)?;
        if self.activation {
            swish(&xs)
        } else {
            Ok(xs)
        }
    }
}

#[derive(Debug)]
struct SqueezeExcitation {
    fc1: Conv2DSame,
    fc2: Conv2DSame,
}

impl SqueezeExcitation {
    fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
        let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
        let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
        Ok(Self { fc1, fc2 })
    }
}

impl Module for SqueezeExcitation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        // equivalent to adaptive_avg_pool2d([1, 1])
        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
        let xs = self.fc1.forward(&xs)?;
        let xs = swish(&xs)?;
        let xs = self.fc2.forward(&xs)?;
        let xs = nn::ops::sigmoid(&xs)?;
        residual.broadcast_mul(&xs)
    }
}

/// Inverted residual block (MBConv) with squeeze-and-excitation.
#[derive(Debug)]
struct MBConv {
    expand_cna: Option<ConvNormActivation>,
    depthwise_cna: ConvNormActivation,
    squeeze_excitation: SqueezeExcitation,
    project_cna: ConvNormActivation,
    config: MBConvConfig,
}

impl MBConv {
    fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
        let vb = vb.pp("block");
        let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
        let expand_cna = if exp != c.input_channels {
            Some(ConvNormActivation::new(
                vb.pp("0"),
                c.input_channels,
                exp,
                1,
                1,
                1,
            )?)
        } else {
            None
        };
        let start_index = if expand_cna.is_some() { 1 } else { 0 };
        let depthwise_cna =
            ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
        let squeeze_channels = usize::max(1, c.input_channels / 4);
        let squeeze_excitation =
            SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
        let project_cna =
            ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
                .no_activation();
        Ok(Self {
            expand_cna,
            depthwise_cna,
            squeeze_excitation,
            project_cna,
            config: c,
        })
    }
}

impl Module for MBConv {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let use_res_connect =
            self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
        let ys = match &self.expand_cna {
            Some(expand_cna) => expand_cna.forward(xs)?,
            None => xs.clone(),
        };
        let ys = self.depthwise_cna.forward(&ys)?;
        let ys = self.squeeze_excitation.forward(&ys)?;
        let ys = self.project_cna.forward(&ys)?;
        if use_res_connect {
            ys + xs
        } else {
            Ok(ys)
        }
    }
}

/// Swish/SiLU activation: x * sigmoid(x).
fn swish(s: &Tensor) -> Result<Tensor> {
    s * nn::ops::sigmoid(s)?
}

#[derive(Debug)]
struct EfficientNet {
    init_cna: ConvNormActivation,
    blocks: Vec<MBConv>,
    final_cna: ConvNormActivation,
    classifier: nn::Linear,
}

impl EfficientNet {
    fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
        let f_p = p.pp("features");
        let first_in_c = configs[0].input_channels;
        let last_out_c = configs.last().unwrap().out_channels;
        let final_out_c = 4 * last_out_c;
        let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
        let nconfigs = configs.len();
        let mut blocks = vec![];
        for (index, cnf) in configs.into_iter().enumerate() {
            let f_p = f_p.pp(index + 1);
            for r_index in 0..cnf.num_layers {
                let cnf = if r_index == 0 {
                    cnf
                } else {
                    MBConvConfig {
                        input_channels: cnf.out_channels,
                        stride: 1,
                        ..cnf
                    }
                };
                blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
            }
        }
        let final_cna =
            ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
        let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
        Ok(Self {
            init_cna,
            blocks,
            final_cna,
            classifier,
        })
    }
}

impl Module for EfficientNet {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = self.init_cna.forward(xs)?;
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
        let xs = self.final_cna.forward(&xs)?;
        // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
        let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
        self.classifier.forward(&xs)
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    B0,
    B1,
    B2,
    B3,
    B4,
    B5,
    B6,
    B7,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Variant of the model to use.
    #[arg(value_enum, long, default_value_t = Which::B2)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-efficientnet".into());
            let filename = match args.which {
                Which::B0 => "efficientnet-b0.safetensors",
                Which::B1 => "efficientnet-b1.safetensors",
                Which::B2 => "efficientnet-b2.safetensors",
                Which::B3 => "efficientnet-b3.safetensors",
                Which::B4 => "efficientnet-b4.safetensors",
                Which::B5 => "efficientnet-b5.safetensors",
                Which::B6 => "efficientnet-b6.safetensors",
                Which::B7 => "efficientnet-b7.safetensors",
            };
            api.get(filename)?
        }
        Some(model) => model.into(),
    };
    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let cfg = match args.which {
        Which::B0 => MBConvConfig::b0(),
        Which::B1 => MBConvConfig::b1(),
        Which::B2 => MBConvConfig::b2(),
        Which::B3 => MBConvConfig::b3(),
        Which::B4 => MBConvConfig::b4(),
        Which::B5 => MBConvConfig::b5(),
        Which::B6 => MBConvConfig::b6(),
        Which::B7 => MBConvConfig::b7(),
    };
    let model = EfficientNet::new(vb, cfg, candle_examples::imagenet::CLASS_COUNT as usize)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
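
// The sanity checks below are an addition, not part of the upstream example: a minimal
// sketch asserting that `make_divisible` rounds scaled channel counts to multiples of 8
// without dropping below 90% of the requested value. Run with `cargo test`.
#[cfg(test)]
mod tests {
    use super::make_divisible;

    #[test]
    fn make_divisible_rounds_to_multiples_of_8() {
        // B0 stem width (32) is already a multiple of 8 and stays unchanged.
        assert_eq!(make_divisible(32.0, 8), 32);
        // B2 width multiplier (1.1): 32 * 1.1 = 35.2 still rounds down to 32.
        assert_eq!(make_divisible(32.0 * 1.1, 8), 32);
        // B4 width multiplier (1.4): 16 * 1.4 = 22.4 rounds up to 24.
        assert_eq!(make_divisible(16.0 * 1.4, 8), 24);
    }
}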