mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Experiment with resnet (#1128)
* Add some preliminary support for resnet. * Add an actual resnet example.
This commit is contained in:
76
candle-examples/examples/resnet/main.rs
Normal file
76
candle-examples/examples/resnet/main.rs
Normal file
@ -0,0 +1,76 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::resnet;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
Resnet18,
|
||||
Resnet34,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Variant of the model to use.
|
||||
#[arg(value_enum, long, default_value_t = Which::Resnet18)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-resnet".into());
|
||||
let filename = match args.which {
|
||||
Which::Resnet18 => "resnet18.safetensors",
|
||||
Which::Resnet34 => "resnet34.safetensors",
|
||||
};
|
||||
api.get(filename)?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let class_count = candle_examples::imagenet::CLASS_COUNT as usize;
|
||||
let model = match args.which {
|
||||
Which::Resnet18 => resnet::resnet18(class_count, vb)?,
|
||||
Which::Resnet34 => resnet::resnet34(class_count, vb)?,
|
||||
};
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -25,3 +25,12 @@ impl<'a> super::Module for Func<'a> {
|
||||
(*self.f)(xs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Func<'a> {
|
||||
pub fn new<F>(f: F) -> Self
|
||||
where
|
||||
F: 'a + Fn(&Tensor) -> Result<Tensor> + Send,
|
||||
{
|
||||
Self { f: Box::new(f) }
|
||||
}
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ pub mod quantized_mixformer;
|
||||
pub mod quantized_mpt;
|
||||
pub mod quantized_stable_lm;
|
||||
pub mod quantized_t5;
|
||||
pub mod resnet;
|
||||
pub mod segment_anything;
|
||||
pub mod stable_diffusion;
|
||||
pub mod stable_lm;
|
||||
|
131
candle-transformers/src/models/resnet.rs
Normal file
131
candle-transformers/src/models/resnet.rs
Normal file
@ -0,0 +1,131 @@
|
||||
//! ResNet implementation.
|
||||
//!
|
||||
//! See "Deep Residual Learning for Image Recognition" He et al. 2015
|
||||
//! <https://arxiv.org/abs/1512.03385>
|
||||
use candle::{Result, D};
|
||||
use candle_nn::{batch_norm, Conv2d, Func, VarBuilder};
|
||||
|
||||
fn conv2d(
|
||||
c_in: usize,
|
||||
c_out: usize,
|
||||
ksize: usize,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv2d> {
|
||||
let conv2d_cfg = candle_nn::Conv2dConfig {
|
||||
stride,
|
||||
padding,
|
||||
..Default::default()
|
||||
};
|
||||
candle_nn::conv2d_no_bias(c_in, c_out, ksize, conv2d_cfg, vb)
|
||||
}
|
||||
|
||||
fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Result<Func> {
|
||||
if stride != 1 || c_in != c_out {
|
||||
let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
|
||||
let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
|
||||
Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
|
||||
} else {
|
||||
Ok(Func::new(|xs| Ok(xs.clone())))
|
||||
}
|
||||
}
|
||||
|
||||
fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Result<Func> {
|
||||
let conv1 = conv2d(c_in, c_out, 3, 1, stride, vb.pp("conv1"))?;
|
||||
let bn1 = batch_norm(c_out, 1e-5, vb.pp("bn1"))?;
|
||||
let conv2 = conv2d(c_out, c_out, 3, 1, 1, vb.pp("conv2"))?;
|
||||
let bn2 = batch_norm(c_out, 1e-5, vb.pp("bn2"))?;
|
||||
let downsample = downsample(c_in, c_out, stride, vb.pp("downsample"))?;
|
||||
Ok(Func::new(move |xs| {
|
||||
let ys = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.relu()?
|
||||
.apply(&conv2)?
|
||||
.apply(&bn2)?;
|
||||
(xs.apply(&downsample)? + ys)?.relu()
|
||||
}))
|
||||
}
|
||||
|
||||
fn basic_layer(
|
||||
c_in: usize,
|
||||
c_out: usize,
|
||||
stride: usize,
|
||||
cnt: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func> {
|
||||
let mut layers = Vec::with_capacity(cnt);
|
||||
for index in 0..cnt {
|
||||
let l_in = if index == 0 { c_in } else { c_out };
|
||||
let stride = if index == 0 { stride } else { 1 };
|
||||
layers.push(basic_block(l_in, c_out, stride, vb.pp(index))?)
|
||||
}
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
for layer in layers.iter() {
|
||||
xs = xs.apply(layer)?
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn resnet(
|
||||
nclasses: Option<usize>,
|
||||
c1: usize,
|
||||
c2: usize,
|
||||
c3: usize,
|
||||
c4: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func> {
|
||||
let conv1 = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
|
||||
let bn1 = batch_norm(64, 1e-5, vb.pp("bn1"))?;
|
||||
let layer1 = basic_layer(64, 64, 1, c1, vb.pp("layer1"))?;
|
||||
let layer2 = basic_layer(64, 128, 2, c2, vb.pp("layer2"))?;
|
||||
let layer3 = basic_layer(128, 256, 2, c3, vb.pp("layer3"))?;
|
||||
let layer4 = basic_layer(256, 512, 2, c4, vb.pp("layer4"))?;
|
||||
let fc = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let linear = candle_nn::linear(512, nclasses, vb.pp("fc"))?;
|
||||
Some(linear)
|
||||
}
|
||||
};
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.relu()?
|
||||
.pad_with_same(D::Minus1, 1, 1)?
|
||||
.pad_with_same(D::Minus2, 1, 1)?
|
||||
.max_pool2d_with_stride(3, 2)?
|
||||
.apply(&layer1)?
|
||||
.apply(&layer2)?
|
||||
.apply(&layer3)?
|
||||
.apply(&layer4)?
|
||||
.mean(D::Minus1)?
|
||||
.mean(D::Minus1)?;
|
||||
match &fc {
|
||||
None => Ok(xs),
|
||||
Some(fc) => xs.apply(fc),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Creates a ResNet-18 model.
|
||||
pub fn resnet18(num_classes: usize, vb: VarBuilder) -> Result<Func> {
|
||||
resnet(Some(num_classes), 2, 2, 2, 2, vb)
|
||||
}
|
||||
|
||||
pub fn resnet18_no_final_layer(vb: VarBuilder) -> Result<Func> {
|
||||
resnet(None, 2, 2, 2, 2, vb)
|
||||
}
|
||||
|
||||
/// Creates a ResNet-34 model.
|
||||
pub fn resnet34(num_classes: usize, vb: VarBuilder) -> Result<Func> {
|
||||
resnet(Some(num_classes), 3, 4, 6, 3, vb)
|
||||
}
|
||||
|
||||
pub fn resnet34_no_final_layer(vb: VarBuilder) -> Result<Func> {
|
||||
resnet(None, 3, 4, 6, 3, vb)
|
||||
}
|
Reference in New Issue
Block a user