From 0acd16751d6e0a501bba6c6285a18ccc40fad59b Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 25 Oct 2023 15:35:32 +0100 Subject: [PATCH] Expose the fields from batch-norm. (#1176) --- candle-nn/src/batch_norm.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs index 27ef15f7..05904859 100644 --- a/candle-nn/src/batch_norm.rs +++ b/candle-nn/src/batch_norm.rs @@ -100,9 +100,19 @@ impl BatchNorm { num_features, }) } -} -impl BatchNorm { + pub fn running_mean(&self) -> &Tensor { + &self.running_mean + } + + pub fn running_var(&self) -> &Tensor { + &self.running_var + } + + pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> { + self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1)) + } + pub fn forward_learning(&self, x: &Tensor) -> Result { let x_dtype = x.dtype(); let internal_dtype = match x_dtype {