mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Allow for different behavior between training and eval (#1213)
* Forward with training.
* Do not use dropout on VGG evaluation.
This commit is contained in:
@ -36,3 +36,38 @@ impl<'a> Func<'a> {
|
||||
Self { f: Arc::new(f) }
|
||||
}
|
||||
}
|
||||
|
||||
/// A layer defined by a simple closure.
///
/// Unlike `Func`, the wrapped closure also receives a `train: bool` flag
/// (see `forward_t` below), so the layer can behave differently between
/// training and evaluation — e.g. applying dropout only while training.
#[derive(Clone)]
pub struct FuncT<'a> {
    #[allow(clippy::type_complexity)]
    // Shared, thread-safe closure: (input tensor, train flag) -> output tensor.
    // `Arc` keeps `FuncT` cheaply `Clone` despite holding an unsized closure.
    f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
}
|
||||
|
||||
impl<'a> std::fmt::Debug for FuncT<'a> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "func")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn func_t<'a, F>(f: F) -> FuncT<'a>
|
||||
where
|
||||
F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
|
||||
{
|
||||
FuncT { f: Arc::new(f) }
|
||||
}
|
||||
|
||||
impl<'a> super::ModuleT for FuncT<'a> {
|
||||
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
|
||||
(*self.f)(xs, train)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> FuncT<'a> {
|
||||
pub fn new<F>(f: F) -> Self
|
||||
where
|
||||
F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
|
||||
{
|
||||
Self { f: Arc::new(f) }
|
||||
}
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ pub use conv::{
|
||||
Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
|
||||
};
|
||||
pub use embedding::{embedding, Embedding};
|
||||
pub use func::{func, Func};
|
||||
pub use func::{func, func_t, Func, FuncT};
|
||||
pub use group_norm::{group_norm, GroupNorm};
|
||||
pub use init::Init;
|
||||
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
|
||||
@ -34,4 +34,4 @@ pub use sequential::{seq, Sequential};
|
||||
pub use var_builder::VarBuilder;
|
||||
pub use var_map::VarMap;
|
||||
|
||||
pub use candle::Module;
|
||||
pub use candle::{Module, ModuleT};
|
||||
|
@ -84,6 +84,12 @@ impl Dropout {
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::ModuleT for Dropout {
|
||||
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
|
||||
self.forward(xs, train)
|
||||
}
|
||||
}
|
||||
|
||||
// Zero-sized marker type for the softmax-on-last-dimension custom op; the
// behavior lives in its `candle::CustomOp1` impl below.
struct SoftmaxLastDim;
|
||||
|
||||
impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
|
Reference in New Issue
Block a user