From cd53c472df163b3baaf7c70ca5d4f8909af62324 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 19 Oct 2023 10:48:31 +0100
Subject: [PATCH] Support ResNet 50/101/152. (#1130)

---
 candle-transformers/src/models/resnet.rs | 118 +++++++++++++++++++++++
 1 file changed, 118 insertions(+)

diff --git a/candle-transformers/src/models/resnet.rs b/candle-transformers/src/models/resnet.rs
index dd6eee58..f2588e01 100644
--- a/candle-transformers/src/models/resnet.rs
+++ b/candle-transformers/src/models/resnet.rs
@@ -129,3 +129,121 @@ pub fn resnet34(num_classes: usize, vb: VarBuilder) -> Result<Func<'static>> {
 pub fn resnet34_no_final_layer(vb: VarBuilder) -> Result<Func<'static>> {
     resnet(None, 3, 4, 6, 3, vb)
 }
+
+// Bottleneck versions for ResNet 50, 101, and 152.
+fn bottleneck_block(
+    c_in: usize,
+    c_out: usize,
+    stride: usize,
+    e: usize,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let e_dim = e * c_out;
+    let conv1 = conv2d(c_in, c_out, 1, 0, 1, vb.pp("conv1"))?;
+    let bn1 = batch_norm(c_out, 1e-5, vb.pp("bn1"))?;
+    let conv2 = conv2d(c_out, c_out, 3, 1, stride, vb.pp("conv2"))?;
+    let bn2 = batch_norm(c_out, 1e-5, vb.pp("bn2"))?;
+    let conv3 = conv2d(c_out, e_dim, 1, 0, 1, vb.pp("conv3"))?;
+    let bn3 = batch_norm(e_dim, 1e-5, vb.pp("bn3"))?;
+    let downsample = downsample(c_in, e_dim, stride, vb.pp("downsample"))?;
+    Ok(Func::new(move |xs| {
+        let ys = xs
+            .apply(&conv1)?
+            .apply(&bn1)?
+            .relu()?
+            .apply(&conv2)?
+            .apply(&bn2)?
+            .relu()?
+            .apply(&conv3)?
+            .apply(&bn3)?;
+        (xs.apply(&downsample)? + ys)?.relu()
+    }))
+}
+
+fn bottleneck_layer(
+    c_in: usize,
+    c_out: usize,
+    stride: usize,
+    cnt: usize,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let mut layers = Vec::with_capacity(cnt);
+    for index in 0..cnt {
+        let l_in = if index == 0 { c_in } else { 4 * c_out };
+        let stride = if index == 0 { stride } else { 1 };
+        layers.push(bottleneck_block(l_in, c_out, stride, 4, vb.pp(index))?)
+    }
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.clone();
+        for layer in layers.iter() {
+            xs = xs.apply(layer)?
+        }
+        Ok(xs)
+    }))
+}
+
+fn bottleneck_resnet(
+    nclasses: Option<usize>,
+    c1: usize,
+    c2: usize,
+    c3: usize,
+    c4: usize,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let conv1 = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
+    let bn1 = batch_norm(64, 1e-5, vb.pp("bn1"))?;
+    let layer1 = bottleneck_layer(64, 64, 1, c1, vb.pp("layer1"))?;
+    let layer2 = bottleneck_layer(4 * 64, 128, 2, c2, vb.pp("layer2"))?;
+    let layer3 = bottleneck_layer(4 * 128, 256, 2, c3, vb.pp("layer3"))?;
+    let layer4 = bottleneck_layer(4 * 256, 512, 2, c4, vb.pp("layer4"))?;
+    let fc = match nclasses {
+        None => None,
+        Some(nclasses) => {
+            let linear = candle_nn::linear(4 * 512, nclasses, vb.pp("fc"))?;
+            Some(linear)
+        }
+    };
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .apply(&conv1)?
+            .apply(&bn1)?
+            .relu()?
+            .pad_with_same(D::Minus1, 1, 1)?
+            .pad_with_same(D::Minus2, 1, 1)?
+            .max_pool2d_with_stride(3, 2)?
+            .apply(&layer1)?
+            .apply(&layer2)?
+            .apply(&layer3)?
+            .apply(&layer4)?
+            .mean(D::Minus1)?
+            .mean(D::Minus1)?;
+        match &fc {
+            None => Ok(xs),
+            Some(fc) => xs.apply(fc),
+        }
+    }))
+}
+
+pub fn resnet50(num_classes: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    bottleneck_resnet(Some(num_classes), 3, 4, 6, 3, vb)
+}
+
+pub fn resnet50_no_final_layer(vb: VarBuilder) -> Result<Func<'static>> {
+    bottleneck_resnet(None, 3, 4, 6, 3, vb)
+}
+
+pub fn resnet101(num_classes: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    bottleneck_resnet(Some(num_classes), 3, 4, 23, 3, vb)
+}
+
+pub fn resnet101_no_final_layer(vb: VarBuilder) -> Result<Func<'static>> {
+    bottleneck_resnet(None, 3, 4, 23, 3, vb)
+}
+
+pub fn resnet152(num_classes: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    bottleneck_resnet(Some(num_classes), 3, 8, 36, 3, vb)
+}
+
+pub fn resnet152_no_final_layer(vb: VarBuilder) -> Result<Func<'static>> {
+    bottleneck_resnet(None, 3, 8, 36, 3, vb)
+}
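
Below is a minimal usage sketch, not part of the patch itself, showing how the new resnet50 constructor could be exercised. It assumes the crate aliasing used inside the candle workspace (candle-core imported as candle), a hypothetical resnet50.safetensors checkpoint whose tensor names match the conv1/bn1/layer1..layer4/fc prefixes above, and a dummy 224x224 RGB input; real images would additionally need the usual ImageNet normalization.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::resnet;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical weight file; any safetensors checkpoint laid out with the
    // conv1/bn1/layer1..layer4/fc naming expected by bottleneck_resnet works.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["resnet50.safetensors"], DType::F32, &device)?
    };
    // 1000 classes as in the standard ImageNet head; use
    // resnet50_no_final_layer to get pooled features instead of logits.
    let model = resnet::resnet50(1000, vb)?;

    // Dummy batch of one 224x224 RGB image.
    let image = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    let logits = image.apply(&model)?;
    println!("{:?}", logits.dims()); // expected: [1, 1000]
    Ok(())
}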