mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Allow for different behavior between training and eval (#1213)
* Forward with training. * Do not use dropout on vgg evaluation.
This commit is contained in:
@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
|
||||
use rand::prelude::*;
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
|
||||
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
|
||||
|
||||
const IMAGE_DIM: usize = 784;
|
||||
const LABELS: usize = 10;
|
||||
@ -95,7 +95,7 @@ impl ConvNet {
|
||||
.flatten_from(1)?
|
||||
.apply(&self.fc1)?
|
||||
.relu()?;
|
||||
self.dropout.forward(&xs, train)?.apply(&self.fc2)
|
||||
self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_nn::{ModuleT, VarBuilder};
|
||||
use candle_transformers::models::vgg::{Models, Vgg};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
@ -53,7 +53,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
|
||||
Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
|
||||
};
|
||||
let logits = model.forward(&image)?;
|
||||
let logits = model.forward_t(&image, /*train=*/ false)?;
|
||||
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
|
Reference in New Issue
Block a user