Expose the fields from batch-norm. (#1176)

This commit is contained in:
Laurent Mazare
2023-10-25 15:35:32 +01:00
committed by GitHub
parent c698e17619
commit 0acd16751d

View File

@ -100,9 +100,19 @@ impl BatchNorm {
num_features,
})
}
}
impl BatchNorm {
pub fn running_mean(&self) -> &Tensor {
&self.running_mean
}
pub fn running_var(&self) -> &Tensor {
&self.running_var
}
pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
}
pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {