[Breaking] Add training to batchnorm with exponential moving average (#1504)

* Add training to batchnorm with exponential moving average

* Add more checks to batch norm

* Resolve some review comments

* Add with_momentum variants of `new` methods

* Add check for range of momentum variable; update batch norm test

* Run cargo fmt

* Add back num_features parameter

* Format; tiny simplification
Authored by nkoppel on 2023-12-30 15:42:08 +00:00, committed by GitHub
parent 51e577a682
commit 4290b81244
2 changed files with 169 additions and 50 deletions


@@ -7,15 +7,22 @@
//! running stats.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Result, Tensor};
+use crate::Init;
+use candle::{DType, Module, Result, Tensor, Var};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNormConfig {
    pub eps: f64,
    pub remove_mean: bool,
    /// The meaning of affine here is different from LayerNorm: when false there is no learnable
    /// parameter at all, 1 used for gamma and 0 for beta.
    pub affine: bool,
+    /// Controls exponential moving average of running stats. Defaults to 0.1
+    ///
+    /// `running_stat * (1.0 - momentum) + stat * momentum`.
+    pub momentum: f64,
}

impl Default for BatchNormConfig {
@@ -24,6 +31,7 @@ impl Default for BatchNormConfig {
            eps: 1e-5,
            remove_mean: true,
            affine: true,
+            momentum: 0.1,
        }
    }
}
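
To make the documented update rule concrete, here is a small standalone illustration (not part of the diff; the starting values are assumed) of how a running statistic drifts toward the batch statistic under the default momentum of 0.1:

    // Illustration only: `running_stat * (1.0 - momentum) + stat * momentum`
    // with an assumed running mean of 0.0 and a batch mean of 1.0 seen three times.
    let momentum = 0.1;
    let mut running_mean = 0.0_f64;
    for stat in [1.0_f64, 1.0, 1.0] {
        running_mean = running_mean * (1.0 - momentum) + stat * momentum;
    }
    // 0.0 -> 0.1 -> 0.19 -> 0.271: with momentum = 0.1 the running stats move slowly.
    assert!((running_mean - 0.271).abs() < 1e-12);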
@@ -32,23 +40,62 @@ impl From<f64> for BatchNormConfig {
    fn from(eps: f64) -> Self {
        Self {
            eps,
-            remove_mean: true,
-            affine: true,
+            ..Default::default()
        }
    }
}

#[derive(Clone, Debug)]
pub struct BatchNorm {
-    running_mean: Tensor,
-    running_var: Tensor,
+    running_mean: Var,
+    running_var: Var,
    weight_and_bias: Option<(Tensor, Tensor)>,
    remove_mean: bool,
    eps: f64,
-    num_features: usize,
+    momentum: f64,
}

impl BatchNorm {
+    fn check_validity(&self, num_features: usize) -> Result<()> {
+        if self.eps < 0. {
+            candle::bail!("batch-norm eps cannot be negative {}", self.eps)
+        }
+        if !(0.0..=1.0).contains(&self.momentum) {
+            candle::bail!(
+                "batch-norm momentum must be between 0 and 1, is {}",
+                self.momentum
+            )
+        }
+        if self.running_mean.dims() != [num_features] {
+            candle::bail!(
+                "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]",
+                self.running_mean.shape(),
+            )
+        }
+        if self.running_var.dims() != [num_features] {
+            candle::bail!(
+                "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]",
+                self.running_var.shape(),
+            )
+        }
+        if let Some((ref weight, ref bias)) = self.weight_and_bias.as_ref() {
+            if weight.dims() != [num_features] {
+                candle::bail!(
+                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+                    weight.shape(),
+                )
+            }
+            if bias.dims() != [num_features] {
+                candle::bail!(
+                    "batch-norm bias has unexpected shape {:?} should have shape [{num_features}]",
+                    bias.shape(),
+                )
+            }
+        }
+        Ok(())
+    }
+
    pub fn new(
        num_features: usize,
        running_mean: Tensor,
@@ -57,29 +104,16 @@ impl BatchNorm {
        bias: Tensor,
        eps: f64,
    ) -> Result<Self> {
-        if eps < 0. {
-            candle::bail!("batch-norm eps cannot be negative {eps}")
-        }
-        if weight.dims() != [num_features] {
-            candle::bail!(
-                "batch-norm unexpected weight shape {:?} {num_features}",
-                weight.shape()
-            )
-        }
-        if bias.dims() != [num_features] {
-            candle::bail!(
-                "batch-norm unexpected bias shape {:?} {num_features}",
-                bias.shape()
-            )
-        }
-        Ok(Self {
-            running_mean,
-            running_var,
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: Some((weight, bias)),
            remove_mean: true,
            eps,
-            num_features,
-        })
+            momentum: 0.1,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
    }

    pub fn new_no_bias(
@@ -88,25 +122,64 @@ impl BatchNorm {
        running_var: Tensor,
        eps: f64,
    ) -> Result<Self> {
-        if eps < 0. {
-            candle::bail!("batch-norm eps cannot be negative {eps}")
-        }
-        Ok(Self {
-            running_mean,
-            running_var,
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: None,
            remove_mean: true,
            eps,
-            num_features,
-        })
+            momentum: 0.1,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
+    }
+
+    pub fn new_with_momentum(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        eps: f64,
+        momentum: f64,
+    ) -> Result<Self> {
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
+            weight_and_bias: Some((weight, bias)),
+            remove_mean: true,
+            eps,
+            momentum,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
+    }
+
+    pub fn new_no_bias_with_momentum(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        eps: f64,
+        momentum: f64,
+    ) -> Result<Self> {
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
+            weight_and_bias: None,
+            remove_mean: true,
+            eps,
+            momentum,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
    }

    pub fn running_mean(&self) -> &Tensor {
-        &self.running_mean
+        self.running_mean.as_tensor()
    }

    pub fn running_var(&self) -> &Tensor {
-        &self.running_var
+        self.running_var.as_tensor()
    }

    pub fn eps(&self) -> f64 {
@@ -117,7 +190,12 @@ impl BatchNorm {
        self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
    }

+    pub fn momentum(&self) -> f64 {
+        self.momentum
+    }
+
    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+        let num_features = self.running_mean.as_tensor().dim(0)?;
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
@@ -129,11 +207,11 @@ impl BatchNorm {
                x.shape()
            )
        }
-        if x.dim(1)? != self.num_features {
+        if x.dim(1)? != num_features {
            candle::bail!(
                "batch-norm input doesn't have the expected number of features ({:?} <> {})",
                x.shape(),
-                self.num_features
+                num_features
            )
        }
        let x = x.to_dtype(internal_dtype)?;
@@ -142,26 +220,52 @@ impl BatchNorm {
        let x = x.flatten_from(1)?.contiguous()?;
        let x = if self.remove_mean {
            let mean_x = x.mean_keepdim(1)?;
+            {
+                // Update running mean
+                let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+                    + (mean_x.flatten_all()? * self.momentum)?)?;
+                self.running_mean.set(&new_mean)?;
+            }
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
        let norm_x = x.sqr()?.mean_keepdim(1)?;
+        {
+            // Update running variance
+            let batch_size = x.dim(1)? as f64;
+            let running_var_weight = 1.0 - self.momentum;
+            let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
+            let new_var = ((self.running_var.as_tensor() * running_var_weight)?
+                + (&norm_x.flatten_all()? * norm_x_weight)?)?;
+            self.running_var.set(&new_var)?;
+        }
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed.to_dtype(x_dtype)?;
        let x = match &self.weight_and_bias {
            None => x,
            Some((weight, bias)) => {
-                let weight = weight.reshape((self.num_features, 1))?;
-                let bias = bias.reshape((self.num_features, 1))?;
+                let weight = weight.reshape(((), 1))?;
+                let bias = bias.reshape(((), 1))?;
                x.broadcast_mul(&weight)?.broadcast_add(&bias)?
            }
        };
        x.reshape(x_dims_post_transpose)?.transpose(0, 1)
    }
+
+    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_learning(x)
+        } else {
+            self.forward(x)
+        }
+    }
}

-impl crate::Module for BatchNorm {
+impl Module for BatchNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let target_shape: Vec<usize> = x
            .dims()
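
One detail worth calling out in the running-variance update above: `norm_x` is the biased (divide-by-n) variance of each feature over the flattened batch elements, and it is rescaled by `batch_size / (batch_size - 1.0)` before being folded into `running_var`, so the running statistic tracks the unbiased (divide-by-n-1) estimate; this appears to mirror PyTorch's convention for `running_var`. A small standalone illustration with assumed values (not part of the diff):

    // Illustration only: the n / (n - 1) factor turns the biased (population)
    // variance into the unbiased (sample) estimate.
    let xs = [1.0_f64, 2.0, 3.0, 4.0];
    let n = xs.len() as f64;
    let mean = xs.iter().sum::<f64>() / n;
    let biased = xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n; // 1.25
    let unbiased = biased * n / (n - 1.0); // 1.666...
    let direct = xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);
    assert!((unbiased - direct).abs() < 1e-12);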
@@ -170,9 +274,13 @@ impl crate::Module for BatchNorm {
            .map(|(idx, v)| if idx == 1 { *v } else { 1 })
            .collect();
        let target_shape = target_shape.as_slice();
        let x = x
-            .broadcast_sub(&self.running_mean.reshape(target_shape)?)?
-            .broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?;
+            .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+            .broadcast_div(
+                &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+            )?;
        match &self.weight_and_bias {
            None => Ok(x),
            Some((weight, bias)) => {
@@ -193,21 +301,21 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
    if config.eps < 0. {
        candle::bail!("batch-norm eps cannot be negative {}", config.eps)
    }
-    let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?;
-    let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?;
+    let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?;
+    let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?;
    let weight_and_bias = if config.affine {
-        let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?;
-        let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?;
+        let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?;
+        let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?;
        Some((weight, bias))
    } else {
        None
    };
    Ok(BatchNorm {
-        running_mean,
-        running_var,
+        running_mean: Var::from_tensor(&running_mean)?,
+        running_var: Var::from_tensor(&running_var)?,
        weight_and_bias,
        remove_mean: config.remove_mean,
        eps: config.eps,
-        num_features,
+        momentum: config.momentum,
    })
}


@@ -16,6 +16,8 @@ input = torch.randn(2, 5, 3, 4)
output = m(input)
print(input.flatten())
print(output.flatten())
+print(m.running_mean)
+print(m.running_var)
*/
#[test]
fn batch_norm() -> Result<()> {
@@ -71,5 +73,14 @@ fn batch_norm() -> Result<()> {
    let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
    let sum_diff2 = diff2.sum_keepdim(0)?;
    assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]);
+    assert_eq!(
+        test_utils::to_vec1_round(bn.running_mean(), 4)?,
+        &[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(bn.running_var(), 4)?,
+        &[0.9972, 0.9842, 0.9956, 0.9866, 0.9898]
+    );
    Ok(())
}
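
Taken together, the pieces in this diff could be exercised roughly as follows. This is a hedged sketch rather than code from the PR: the crate paths (`candle` for candle-core, `candle_nn` for the crate patched here) and the local names (`run`, `dev`, `xs`) are assumptions, while `new_with_momentum` and `forward_t` are the methods added above.

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::BatchNorm;

    fn run() -> Result<()> {
        let dev = Device::Cpu;
        let num_features = 5;

        // Fresh running statistics and affine parameters, one value per feature.
        let running_mean = Tensor::zeros(num_features, DType::F32, &dev)?;
        let running_var = Tensor::ones(num_features, DType::F32, &dev)?;
        let weight = Tensor::ones(num_features, DType::F32, &dev)?;
        let bias = Tensor::zeros(num_features, DType::F32, &dev)?;

        // eps = 1e-5 and momentum = 0.1 mirror the defaults used in this commit.
        let bn = BatchNorm::new_with_momentum(
            num_features,
            running_mean,
            running_var,
            weight,
            bias,
            1e-5,
            0.1,
        )?;

        // A (batch, features, height, width) input, as in the PyTorch snippet above.
        let xs = Tensor::randn(0f32, 1f32, (2, num_features, 3, 4), &dev)?;

        // Training pass: normalizes with batch statistics and updates the running stats.
        let _ys_train = bn.forward_t(&xs, true)?;

        // Inference pass: normalizes with the (now updated) running statistics.
        let _ys_eval = bn.forward_t(&xs, false)?;
        Ok(())
    }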