diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index f2aed1b6..fcc17afc 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -129,6 +129,15 @@ impl Result> Module for T { } } +impl Module for Option<&M> { + fn forward(&self, xs: &Tensor) -> Result { + match self { + None => Ok(xs.clone()), + Some(m) => m.forward(xs), + } + } +} + // A trait defining a module with forward method using a single tensor argument and a flag to // separate the training and evaluation behaviors. pub trait ModuleT {