Allow for different behavior between training and eval (#1213)

* Forward with training.

* Do not use dropout on vgg evaluation.
This commit is contained in:
Laurent Mazare
2023-10-29 07:53:09 +01:00
committed by GitHub
parent dece37c6f4
commit 55bc3382cf
8 changed files with 83 additions and 22 deletions

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_nn::{ModuleT, VarBuilder};
use candle_transformers::models::vgg::{Models, Vgg};
use clap::{Parser, ValueEnum};
@ -53,7 +53,7 @@ pub fn main() -> anyhow::Result<()> {
Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
};
let logits = model.forward(&image)?;
let logits = model.forward_t(&image, /*train=*/ false)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?