Do not implement Module for BatchNorm. (#1513)
@@ -7,7 +7,7 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Module, Result, Tensor, Var};
+use candle::{DType, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
@@ -192,7 +192,7 @@ impl BatchNorm {
         self.momentum
     }
 
-    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
         let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -252,17 +252,7 @@ impl BatchNorm {
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
 
-    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
-        if train {
-            self.forward_learning(x)
-        } else {
-            self.forward(x)
-        }
-    }
-}
-
-impl Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
             .iter()
@@ -288,6 +278,16 @@ impl Module for BatchNorm {
     }
 }
 
+impl crate::ModuleT for BatchNorm {
+    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_train(x)
+        } else {
+            self.forward_eval(x)
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
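For callers, the visible effect is that `BatchNorm` can no longer be driven through the `Module` trait: the train/eval distinction must be stated explicitly via `ModuleT::forward_t` (or by calling `forward_train` directly). A minimal sketch of a call site after this change, assuming a `VarMap`-backed `VarBuilder` and the `From<f64>` eps conversion for `BatchNormConfig` (the `demo` function itself is illustrative only):

use candle::{DType, Device, Result, Tensor};
use candle_nn::{batch_norm, BatchNorm, ModuleT, VarBuilder, VarMap};

fn demo() -> Result<()> {
    // Hypothetical setup: fresh variables on CPU, 5 features, eps = 1e-5.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let bn: BatchNorm = batch_norm(5, 1e-5, vb)?;

    let xs = Tensor::randn(0f32, 1f32, (8, 5), &Device::Cpu)?;

    // `bn.forward(&xs)` via Module no longer compiles; the mode is explicit:
    // `true` normalizes with batch statistics (updating the running stats),
    // `false` normalizes with the stored running statistics.
    let ys_train = bn.forward_t(&xs, true)?;
    let ys_eval = bn.forward_t(&xs, false)?;
    assert_eq!(ys_train.dims(), xs.dims());
    assert_eq!(ys_eval.dims(), xs.dims());
    Ok(())
}

Routing everything through `ModuleT` presumably removes a silent failure mode: a `Module`-generic training loop would otherwise call the eval path, normalizing with frozen running stats and never updating them.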