Allow for different behavior between training and eval (#1213)

* Forward with training.

* Do not use dropout on vgg evaluation.
This commit is contained in:
Laurent Mazare
2023-10-29 07:53:09 +01:00
committed by GitHub
parent dece37c6f4
commit 55bc3382cf
8 changed files with 83 additions and 22 deletions

View File

@ -125,3 +125,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
self(xs)
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}

View File

@ -2271,6 +2271,11 @@ impl Tensor {
m.forward(self)
}
/// Run the `forward` method of `m` on `self`.
pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
m.forward_t(self, train)
}
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}