mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Allow for different behavior between training and eval (#1213)
* Forward with training. * Do not use dropout on vgg evaluation.
This commit is contained in:
@ -125,3 +125,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
||||
self(xs)
|
||||
}
|
||||
}
|
||||
|
||||
// A trait defining a module with forward method using a single tensor argument and a flag to
|
||||
// separate the training and evaluation behaviors.
|
||||
pub trait ModuleT {
|
||||
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
impl<M: Module> ModuleT for M {
|
||||
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
|
||||
self.forward(xs)
|
||||
}
|
||||
}
|
||||
|
@ -2271,6 +2271,11 @@ impl Tensor {
|
||||
m.forward(self)
|
||||
}
|
||||
|
||||
/// Run the `forward` method of `m` on `self`.
|
||||
pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
|
||||
m.forward_t(self, train)
|
||||
}
|
||||
|
||||
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
self.storage.read().unwrap()
|
||||
}
|
||||
|
Reference in New Issue
Block a user