Allow for different behavior between training and eval (#1213)

* Forward with training.

* Do not use dropout during VGG evaluation.
Authored by Laurent Mazare on 2023-10-29 07:53:09 +01:00, committed by GitHub
parent dece37c6f4
commit 55bc3382cf
8 changed files with 83 additions and 22 deletions
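
For context, a minimal usage sketch of the new train/eval split (not part of the commit): with the switch from `Module` to `ModuleT`, callers pass a `train` flag to `forward_t`, so the same VGG graph applies dropout during training and skips it at evaluation. The `candle_transformers::models::vgg` module path, the `Vgg::new(vb, Models::Vgg13)` constructor and the dummy input below are assumptions based on the surrounding code, not something this diff shows.

// Sketch only: assumes the vgg module is exposed as candle_transformers::models::vgg
// and that `vb` already maps to loaded VGG13 weights.
use candle::{DType, Device, ModuleT, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::vgg::{Models, Vgg};

fn classify(vb: VarBuilder) -> Result<Tensor> {
    let model = Vgg::new(vb, Models::Vgg13)?;
    // A dummy 3x224x224 image; forward_t unsqueezes it into a batch of one.
    let image = Tensor::zeros((3, 224, 224), DType::F32, &Device::Cpu)?;
    // Evaluation: train = false, so the Dropout(0.5) layers pass activations through unchanged.
    let logits = model.forward_t(&image, false)?;
    // Training-style pass: train = true re-enables the three dropout layers.
    let _train_logits = model.forward_t(&image, true)?;
    Ok(logits)
}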


@@ -2,8 +2,8 @@
 //!
 //! See Very Deep Convolutional Networks for Large-Scale Image Recognition
 //! <https://arxiv.org/abs/1409.1556>
-use candle::{Module, Result, Tensor};
-use candle_nn::{Func, VarBuilder};
+use candle::{ModuleT, Result, Tensor};
+use candle_nn::{FuncT, VarBuilder};
 
 // Enum representing the different VGG models
 pub enum Models {
@@ -15,7 +15,7 @@ pub enum Models {
 // Struct representing a VGG model
 #[derive(Debug)]
 pub struct Vgg<'a> {
-    blocks: Vec<Func<'a>>,
+    blocks: Vec<FuncT<'a>>,
 }
 
 // Struct representing the configuration for the pre-logit layer
@@ -39,11 +39,11 @@ impl<'a> Vgg<'a> {
 }
 
 // Implementation of the forward pass for the VGG model
-impl Module for Vgg<'_> {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl ModuleT for Vgg<'_> {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
         let mut xs = xs.unsqueeze(0)?;
         for block in self.blocks.iter() {
-            xs = xs.apply(block)?;
+            xs = xs.apply_t(block, train)?;
         }
         Ok(xs)
     }
@@ -51,7 +51,7 @@ impl Module for Vgg<'_> {
 
 // Function to create a conv2d block
 // The block is composed of two conv2d layers followed by a max pool layer
-fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func<'static>> {
+fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
     let layers = convs
         .iter()
         .enumerate()
@@ -70,7 +70,7 @@ fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func<
         })
         .collect::<Result<Vec<_>>>()?;
 
-    Ok(Func::new(move |xs| {
+    Ok(FuncT::new(move |xs, _train| {
         let mut xs = xs.clone();
         for layer in layers.iter() {
             xs = xs.apply(layer)?.relu()?
@@ -87,7 +87,7 @@ fn fully_connected(
     pre_logit_1: PreLogitConfig,
     pre_logit_2: PreLogitConfig,
     vb: VarBuilder,
-) -> Result<Func> {
+) -> Result<FuncT> {
     let lin = get_weights_and_biases(
         &vb.pp("pre_logits.fc1"),
         pre_logit_1.in_dim,
@@ -100,12 +100,15 @@ fn fully_connected(
         pre_logit_2.target_in,
         pre_logit_2.target_out,
     )?;
-    Ok(Func::new(move |xs| {
+    let dropout1 = candle_nn::Dropout::new(0.5);
+    let dropout2 = candle_nn::Dropout::new(0.5);
+    let dropout3 = candle_nn::Dropout::new(0.5);
+    Ok(FuncT::new(move |xs, train| {
         let xs = xs.reshape((1, pre_logit_1.target_out))?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin)?.relu()?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin2)?.relu()?;
+        let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?;
+        let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?;
         let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin3)?.relu()?;
+        let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?;
         Ok(xs)
     }))
 }
@@ -130,7 +133,7 @@ fn get_weights_and_biases(
     Ok(candle_nn::Linear::new(ws, Some(bs)))
 }
 
-fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
+fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
@@ -156,7 +159,7 @@ fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
     Ok(blocks)
 }
 
-fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
+fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
@@ -203,7 +206,7 @@ fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
     Ok(blocks)
 }
 
-fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
+fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,