Do not implement Module for BatchNorm. (#1513)

Laurent Mazare, 2024-01-01 10:13:13 +01:00, committed by GitHub
parent 1fb2dd905c · commit b0fe5e4453
9 changed files with 31 additions and 33 deletions


@@ -7,7 +7,7 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Module, Result, Tensor, Var};
+use candle::{DType, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
@@ -192,7 +192,7 @@ impl BatchNorm {
         self.momentum
     }
 
-    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
         let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -252,17 +252,7 @@
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
 
-    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
-        if train {
-            self.forward_learning(x)
-        } else {
-            self.forward(x)
-        }
-    }
-}
-
-impl Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
             .iter()
@@ -288,6 +278,16 @@ impl Module for BatchNorm {
     }
 }
 
+impl crate::ModuleT for BatchNorm {
+    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_train(x)
+        } else {
+            self.forward_eval(x)
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
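
A minimal caller-side sketch of the API after this change (the apply_bn helper is illustrative and not part of the commit; crate paths follow the imports used in the diff): BatchNorm no longer implements Module, so callers go through ModuleT::forward_t and state the training/evaluation choice explicitly instead of relying on a plain Module::forward.

use candle::{Result, Tensor};
use candle_nn::{BatchNorm, ModuleT};

// Illustrative helper routing through the ModuleT impl added above.
// `training == true` takes the batch-statistics path (forward_train),
// `training == false` normalises with the stored running mean/variance.
fn apply_bn(bn: &BatchNorm, x: &Tensor, training: bool) -> Result<Tensor> {
    bn.forward_t(x, training)
}

Making the call site name the mode is the practical effect of dropping the Module impl: code that previously called forward on a BatchNorm must now decide between the train and eval paths explicitly.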