mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Allow for different behavior between training and eval (#1213)

* Forward with training.
* Do not use dropout on vgg evaluation.
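For context, `candle::ModuleT` is the train-aware counterpart of `Module`: `forward_t` takes an extra `train: bool`, which lets a layer such as `candle_nn::Dropout` drop activations during training but act as the identity at evaluation. A minimal illustrative sketch of a module branching on that flag (a stand-in, not the library's actual `Dropout` source):

use candle::{ModuleT, Result, Tensor};

// Toy train-aware layer; candle_nn::Dropout follows the same pattern.
struct ToyDropout {
    drop_p: f32,
}

impl ModuleT for ToyDropout {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            // Training: randomly zero activations with probability drop_p.
            candle_nn::ops::dropout(xs, self.drop_p)
        } else {
            // Evaluation: pass the input through unchanged.
            Ok(xs.clone())
        }
    }
}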
--- a/candle-transformers/src/models/vgg.rs
+++ b/candle-transformers/src/models/vgg.rs
@@ -2,8 +2,8 @@
 //!
 //! See Very Deep Convolutional Networks for Large-Scale Image Recognition
 //! <https://arxiv.org/abs/1409.1556>
-use candle::{Module, Result, Tensor};
-use candle_nn::{Func, VarBuilder};
+use candle::{ModuleT, Result, Tensor};
+use candle_nn::{FuncT, VarBuilder};
 
 // Enum representing the different VGG models
 pub enum Models {
@@ -15,7 +15,7 @@ pub enum Models {
 // Struct representing a VGG model
 #[derive(Debug)]
 pub struct Vgg<'a> {
-    blocks: Vec<Func<'a>>,
+    blocks: Vec<FuncT<'a>>,
 }
 
 // Struct representing the configuration for the pre-logit layer
@@ -39,11 +39,11 @@ impl<'a> Vgg<'a> {
 }
 
 // Implementation of the forward pass for the VGG model
-impl Module for Vgg<'_> {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl ModuleT for Vgg<'_> {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
         let mut xs = xs.unsqueeze(0)?;
         for block in self.blocks.iter() {
-            xs = xs.apply(block)?;
+            xs = xs.apply_t(block, train)?;
         }
         Ok(xs)
     }
@@ -51,7 +51,7 @@ impl Module for Vgg<'_> {
 
 // Function to create a conv2d block
 // The block is composed of two conv2d layers followed by a max pool layer
-fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func<'static>> {
+fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
     let layers = convs
         .iter()
         .enumerate()
@@ -70,7 +70,7 @@ fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func<'static>> {
         })
         .collect::<Result<Vec<_>>>()?;
 
-    Ok(Func::new(move |xs| {
+    Ok(FuncT::new(move |xs, _train| {
         let mut xs = xs.clone();
         for layer in layers.iter() {
             xs = xs.apply(layer)?.relu()?
@@ -87,7 +87,7 @@ fn fully_connected(
     pre_logit_1: PreLogitConfig,
     pre_logit_2: PreLogitConfig,
     vb: VarBuilder,
-) -> Result<Func> {
+) -> Result<FuncT> {
     let lin = get_weights_and_biases(
         &vb.pp("pre_logits.fc1"),
         pre_logit_1.in_dim,
@@ -100,12 +100,15 @@ fn fully_connected(
         pre_logit_2.target_in,
         pre_logit_2.target_out,
     )?;
-    Ok(Func::new(move |xs| {
+    let dropout1 = candle_nn::Dropout::new(0.5);
+    let dropout2 = candle_nn::Dropout::new(0.5);
+    let dropout3 = candle_nn::Dropout::new(0.5);
+    Ok(FuncT::new(move |xs, train| {
         let xs = xs.reshape((1, pre_logit_1.target_out))?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin)?.relu()?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin2)?.relu()?;
+        let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?;
+        let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?;
         let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin3)?.relu()?;
+        let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?;
         Ok(xs)
     }))
 }
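Note the pattern in the new `fully_connected` above: the three `Dropout` modules are built once, outside the closure, then moved into `FuncT::new`, which threads the `train` flag to them via `apply_t`. A minimal sketch of the same pattern in isolation, assuming only the APIs visible in this diff:

use candle::{Result, Tensor};
use candle_nn::{Dropout, FuncT};

// Build the stateful module once and capture it; the closure receives
// the train flag and forwards it, so eval passes are identity ops.
fn dropout_block() -> FuncT<'static> {
    let dropout = Dropout::new(0.5);
    FuncT::new(move |xs: &Tensor, train: bool| xs.apply_t(&dropout, train))
}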
@@ -130,7 +133,7 @@ fn get_weights_and_biases(
     Ok(candle_nn::Linear::new(ws, Some(bs)))
 }
 
-fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
+fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
@@ -156,7 +159,7 @@ fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
     Ok(blocks)
 }
 
-fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
+fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
@@ -203,7 +206,7 @@ fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
     Ok(blocks)
 }
 
-fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
+fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
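Downstream, inference code now disables dropout explicitly by passing `train = false`. A hedged usage sketch; `Vgg::new` and the `Models::Vgg13` variant are assumed from the parts of vgg.rs this diff does not show:

use candle::{ModuleT, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::vgg::{Models, Vgg};

// Classify one image in eval mode: the three dropout layers are
// skipped. Passing `true` instead re-enables them for training.
fn run_eval(image: &Tensor, vb: VarBuilder) -> Result<Tensor> {
    let model = Vgg::new(vb, Models::Vgg13)?;
    model.forward_t(image, false)
}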