diff --git a/candle-examples/examples/resnet/main.rs b/candle-examples/examples/resnet/main.rs
new file mode 100644
index 00000000..3badc48e
--- /dev/null
+++ b/candle-examples/examples/resnet/main.rs
@@ -0,0 +1,76 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::resnet;
+use clap::{Parser, ValueEnum};
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+    Resnet18,
+    Resnet34,
+}
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    #[arg(long)]
+    image: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Variant of the model to use.
+    #[arg(value_enum, long, default_value_t = Which::Resnet18)]
+    which: Which,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let image = candle_examples::imagenet::load_image224(args.image)?;
+    println!("loaded image {image:?}");
+
+    let model_file = match args.model {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("lmz/candle-resnet".into());
+            let filename = match args.which {
+                Which::Resnet18 => "resnet18.safetensors",
+                Which::Resnet34 => "resnet34.safetensors",
+            };
+            api.get(filename)?
+        }
+        Some(model) => model.into(),
+    };
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+    let class_count = candle_examples::imagenet::CLASS_COUNT as usize;
+    let model = match args.which {
+        Which::Resnet18 => resnet::resnet18(class_count, vb)?,
+        Which::Resnet34 => resnet::resnet34(class_count, vb)?,
+    };
+    println!("model built");
+    let logits = model.forward(&image.unsqueeze(0)?)?;
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+        .i(0)?
+        .to_vec1::<f32>()?;
+    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+    for &(category_idx, pr) in prs.iter().take(5) {
+        println!(
+            "{:24}: {:.2}%",
+            candle_examples::imagenet::CLASSES[category_idx],
+            100. * pr
+        );
+    }
+    Ok(())
+}
diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs
index c9ec287a..e7fd73ae 100644
--- a/candle-nn/src/func.rs
+++ b/candle-nn/src/func.rs
@@ -25,3 +25,12 @@ impl<'a> super::Module for Func<'a> {
         (*self.f)(xs)
     }
 }
+
+impl<'a> Func<'a> {
+    pub fn new<F>(f: F) -> Self
+    where
+        F: 'a + Fn(&Tensor) -> Result<Tensor> + Send,
+    {
+        Self { f: Box::new(f) }
+    }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index fc57e732..3b4ef7e1 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -14,6 +14,7 @@ pub mod quantized_mixformer;
 pub mod quantized_mpt;
 pub mod quantized_stable_lm;
 pub mod quantized_t5;
+pub mod resnet;
 pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
diff --git a/candle-transformers/src/models/resnet.rs b/candle-transformers/src/models/resnet.rs
new file mode 100644
index 00000000..dd6eee58
--- /dev/null
+++ b/candle-transformers/src/models/resnet.rs
@@ -0,0 +1,131 @@
+//! ResNet implementation.
+//!
+//! See "Deep Residual Learning for Image Recognition" He et al. 2015
+//! <https://arxiv.org/abs/1512.03385>
+use candle::{Result, D};
+use candle_nn::{batch_norm, Conv2d, Func, VarBuilder};
+
+fn conv2d(
+    c_in: usize,
+    c_out: usize,
+    ksize: usize,
+    padding: usize,
+    stride: usize,
+    vb: VarBuilder,
+) -> Result<Conv2d> {
+    let conv2d_cfg = candle_nn::Conv2dConfig {
+        stride,
+        padding,
+        ..Default::default()
+    };
+    candle_nn::conv2d_no_bias(c_in, c_out, ksize, conv2d_cfg, vb)
+}
+
+fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    if stride != 1 || c_in != c_out {
+        let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
+        let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
+        Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
+    } else {
+        Ok(Func::new(|xs| Ok(xs.clone())))
+    }
+}
+
+fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let conv1 = conv2d(c_in, c_out, 3, 1, stride, vb.pp("conv1"))?;
+    let bn1 = batch_norm(c_out, 1e-5, vb.pp("bn1"))?;
+    let conv2 = conv2d(c_out, c_out, 3, 1, 1, vb.pp("conv2"))?;
+    let bn2 = batch_norm(c_out, 1e-5, vb.pp("bn2"))?;
+    let downsample = downsample(c_in, c_out, stride, vb.pp("downsample"))?;
+    Ok(Func::new(move |xs| {
+        let ys = xs
+            .apply(&conv1)?
+            .apply(&bn1)?
+            .relu()?
+            .apply(&conv2)?
+            .apply(&bn2)?;
+        (xs.apply(&downsample)? + ys)?.relu()
+    }))
+}
+
+fn basic_layer(
+    c_in: usize,
+    c_out: usize,
+    stride: usize,
+    cnt: usize,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let mut layers = Vec::with_capacity(cnt);
+    for index in 0..cnt {
+        let l_in = if index == 0 { c_in } else { c_out };
+        let stride = if index == 0 { stride } else { 1 };
+        layers.push(basic_block(l_in, c_out, stride, vb.pp(index))?)
+    }
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.clone();
+        for layer in layers.iter() {
+            xs = xs.apply(layer)?
+        }
+        Ok(xs)
+    }))
+}
+
+fn resnet(
+    nclasses: Option<usize>,
+    c1: usize,
+    c2: usize,
+    c3: usize,
+    c4: usize,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let conv1 = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
+    let bn1 = batch_norm(64, 1e-5, vb.pp("bn1"))?;
+    let layer1 = basic_layer(64, 64, 1, c1, vb.pp("layer1"))?;
+    let layer2 = basic_layer(64, 128, 2, c2, vb.pp("layer2"))?;
+    let layer3 = basic_layer(128, 256, 2, c3, vb.pp("layer3"))?;
+    let layer4 = basic_layer(256, 512, 2, c4, vb.pp("layer4"))?;
+    let fc = match nclasses {
+        None => None,
+        Some(nclasses) => {
+            let linear = candle_nn::linear(512, nclasses, vb.pp("fc"))?;
+            Some(linear)
+        }
+    };
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .apply(&conv1)?
+            .apply(&bn1)?
+            .relu()?
+            .pad_with_same(D::Minus1, 1, 1)?
+            .pad_with_same(D::Minus2, 1, 1)?
+            .max_pool2d_with_stride(3, 2)?
+            .apply(&layer1)?
+            .apply(&layer2)?
+            .apply(&layer3)?
+            .apply(&layer4)?
+            .mean(D::Minus1)?
+            .mean(D::Minus1)?;
+        match &fc {
+            None => Ok(xs),
+            Some(fc) => xs.apply(fc),
+        }
+    }))
+}
+
+/// Creates a ResNet-18 model.
+pub fn resnet18(num_classes: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    resnet(Some(num_classes), 2, 2, 2, 2, vb)
+}
+
+pub fn resnet18_no_final_layer(vb: VarBuilder) -> Result<Func<'static>> {
+    resnet(None, 2, 2, 2, 2, vb)
+}
+
+/// Creates a ResNet-34 model.
+pub fn resnet34(num_classes: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    resnet(Some(num_classes), 3, 4, 6, 3, vb)
+}
+
+pub fn resnet34_no_final_layer(vb: VarBuilder) -> Result<Func<'static>> {
+    resnet(None, 3, 4, 6, 3, vb)
+}
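
A quick way to smoke-test the new module without downloading the `lmz/candle-resnet` weights is sketched below. This is not part of the diff: it only relies on the `resnet18` constructor and the `Func::new` / `Module` plumbing added above, and it uses zero-initialized weights via `VarBuilder::zeros`, so it checks the wiring and the output shape rather than real predictions.

```rust
use candle::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::resnet;

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    // Zero-initialized weights: enough to exercise the graph, useless for actual classification.
    let vb = VarBuilder::zeros(DType::F32, &device);
    let model = resnet::resnet18(1000, vb)?;
    // Dummy ImageNet-sized input: batch of 1, 3 channels, 224x224.
    let xs = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    let logits = model.forward(&xs)?;
    println!("{:?}", logits.shape()); // expected: [1, 1000]
    Ok(())
}
```

The new `candle-examples/examples/resnet/main.rs` above does the same forward pass, but with the safetensors weights from the hub and a real image.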