mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Allow for different behavior between training and eval (#1213)
* Forward with training. * Do not use dropout on vgg evaluation.
This commit is contained in:
@ -36,3 +36,38 @@ impl<'a> Func<'a> {
|
||||
Self { f: Arc::new(f) }
|
||||
}
|
||||
}
|
||||
|
||||
/// A layer defined by a simple closure.
|
||||
#[derive(Clone)]
|
||||
pub struct FuncT<'a> {
|
||||
#[allow(clippy::type_complexity)]
|
||||
f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
|
||||
}
|
||||
|
||||
impl<'a> std::fmt::Debug for FuncT<'a> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "func")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn func_t<'a, F>(f: F) -> FuncT<'a>
|
||||
where
|
||||
F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
|
||||
{
|
||||
FuncT { f: Arc::new(f) }
|
||||
}
|
||||
|
||||
impl<'a> super::ModuleT for FuncT<'a> {
|
||||
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
|
||||
(*self.f)(xs, train)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> FuncT<'a> {
|
||||
pub fn new<F>(f: F) -> Self
|
||||
where
|
||||
F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
|
||||
{
|
||||
Self { f: Arc::new(f) }
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user