Mirror of https://github.com/huggingface/candle.git
[Breaking] Add training to batchnorm with exponential moving average (#1504)
* Add training to batchnorm with exponential moving average
* Add more checks to batch norm
* Resolve some review comments
* Add with_momentum variants of `new` methods
* Add check for range of momentum variable; update batch norm test
* Run cargo fmt
* Add back num_features parameter
* Format; tiny simplification
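In practice, the API introduced by this diff (the `*_with_momentum` constructors and `forward_t`) can be exercised roughly as in the sketch below. This is illustrative only and not code from the commit; it assumes `candle_nn`'s usual re-export of `BatchNorm` and the signatures visible in the hunks that follow.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::BatchNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let num_features = 4;
    // Freshly initialized running stats: mean 0, variance 1.
    let running_mean = Tensor::zeros(num_features, DType::F32, &dev)?;
    let running_var = Tensor::ones(num_features, DType::F32, &dev)?;
    let weight = Tensor::ones(num_features, DType::F32, &dev)?;
    let bias = Tensor::zeros(num_features, DType::F32, &dev)?;

    // Momentum of 0.2 instead of the 0.1 default used by `new`.
    let bn = BatchNorm::new_with_momentum(
        num_features, running_mean, running_var, weight, bias, 1e-5, 0.2,
    )?;

    // A batch of 8 samples with 4 features each.
    let x = Tensor::randn(0f32, 1f32, (8, num_features), &dev)?;
    // `train = true` routes through `forward_learning`, which also updates the running stats.
    let _y = bn.forward_t(&x, true)?;
    println!("running mean after one step: {:?}", bn.running_mean());
    Ok(())
}
```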
@@ -7,15 +7,22 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Result, Tensor};
+use crate::Init;
+use candle::{DType, Module, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
     pub eps: f64,
     pub remove_mean: bool,
+
     /// The meaning of affine here is different from LayerNorm: when false there is no learnable
     /// parameter at all, 1 used for gamma and 0 for beta.
     pub affine: bool,
+
+    /// Controls exponential moving average of running stats. Defaults to 0.1
+    ///
+    /// `running_stat * (1.0 - momentum) + stat * momentum`.
+    pub momentum: f64,
 }
 
 impl Default for BatchNormConfig {
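The doc comment added above spells out the update rule. As a purely illustrative scalar sketch (made-up numbers, not part of the commit), repeated applications of that rule move the running statistic geometrically towards the batch statistic:

```rust
// Illustrative only: one running statistic tracked with the documented rule
// `running_stat * (1.0 - momentum) + stat * momentum`.
fn ema_step(running: f64, batch_stat: f64, momentum: f64) -> f64 {
    running * (1.0 - momentum) + batch_stat * momentum
}

fn main() {
    let momentum = 0.1;
    let mut running_mean = 0.0; // freshly initialized running mean
    for batch_mean in [2.0, 2.0, 2.0] {
        running_mean = ema_step(running_mean, batch_mean, momentum);
    }
    // After three identical batches: 2.0 * (1.0 - 0.9^3) = 0.542
    println!("{running_mean:.3}");
}
```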
@@ -24,6 +31,7 @@ impl Default for BatchNormConfig {
             eps: 1e-5,
             remove_mean: true,
             affine: true,
+            momentum: 0.1,
         }
     }
 }
@@ -32,23 +40,62 @@ impl From<f64> for BatchNormConfig {
     fn from(eps: f64) -> Self {
         Self {
             eps,
-            remove_mean: true,
-            affine: true,
+            ..Default::default()
         }
     }
 }
 
 #[derive(Clone, Debug)]
 pub struct BatchNorm {
-    running_mean: Tensor,
-    running_var: Tensor,
+    running_mean: Var,
+    running_var: Var,
     weight_and_bias: Option<(Tensor, Tensor)>,
     remove_mean: bool,
     eps: f64,
-    num_features: usize,
+    momentum: f64,
 }
 
 impl BatchNorm {
+    fn check_validity(&self, num_features: usize) -> Result<()> {
+        if self.eps < 0. {
+            candle::bail!("batch-norm eps cannot be negative {}", self.eps)
+        }
+        if !(0.0..=1.0).contains(&self.momentum) {
+            candle::bail!(
+                "batch-norm momentum must be between 0 and 1, is {}",
+                self.momentum
+            )
+        }
+        if self.running_mean.dims() != [num_features] {
+            candle::bail!(
+                "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]",
+                self.running_mean.shape(),
+            )
+        }
+        if self.running_var.dims() != [num_features] {
+            candle::bail!(
+                "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]",
+                self.running_var.shape(),
+            )
+        }
+        if let Some((ref weight, ref bias)) = self.weight_and_bias.as_ref() {
+            if weight.dims() != [num_features] {
+                candle::bail!(
+                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+                    weight.shape(),
+                )
+            }
+            if bias.dims() != [num_features] {
+                candle::bail!(
+                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+                    bias.shape(),
+                )
+            }
+        }
+
+        Ok(())
+    }
+
     pub fn new(
         num_features: usize,
         running_mean: Tensor,
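To make the effect of the new `check_validity` concrete, here is a hedged sketch (using the constructors added later in this diff, and assuming `candle_nn`'s re-export of `BatchNorm`) of two constructions it now rejects:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::BatchNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mean = Tensor::zeros(4, DType::F32, &dev)?;
    let var = Tensor::ones(4, DType::F32, &dev)?;

    // A momentum outside [0, 1] is rejected by check_validity.
    let bad_momentum =
        BatchNorm::new_no_bias_with_momentum(4, mean.clone(), var.clone(), 1e-5, 1.5);
    assert!(bad_momentum.is_err());

    // Running stats whose length does not match num_features are rejected too.
    let bad_shape = BatchNorm::new_no_bias(8, mean, var, 1e-5);
    assert!(bad_shape.is_err());
    Ok(())
}
```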
@@ -57,29 +104,16 @@ impl BatchNorm {
         bias: Tensor,
         eps: f64,
     ) -> Result<Self> {
-        if eps < 0. {
-            candle::bail!("batch-norm eps cannot be negative {eps}")
-        }
-        if weight.dims() != [num_features] {
-            candle::bail!(
-                "batch-norm unexpected weight shape {:?} {num_features}",
-                weight.shape()
-            )
-        }
-        if bias.dims() != [num_features] {
-            candle::bail!(
-                "batch-norm unexpected bias shape {:?} {num_features}",
-                bias.shape()
-            )
-        }
-        Ok(Self {
-            running_mean,
-            running_var,
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
             weight_and_bias: Some((weight, bias)),
             remove_mean: true,
             eps,
-            num_features,
-        })
+            momentum: 0.1,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
     }
 
     pub fn new_no_bias(
@@ -88,25 +122,64 @@ impl BatchNorm {
         running_var: Tensor,
         eps: f64,
     ) -> Result<Self> {
-        if eps < 0. {
-            candle::bail!("batch-norm eps cannot be negative {eps}")
-        }
-        Ok(Self {
-            running_mean,
-            running_var,
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
             weight_and_bias: None,
             remove_mean: true,
             eps,
-            num_features,
-        })
+            momentum: 0.1,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
     }
 
+    pub fn new_with_momentum(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        eps: f64,
+        momentum: f64,
+    ) -> Result<Self> {
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
+            weight_and_bias: Some((weight, bias)),
+            remove_mean: true,
+            eps,
+            momentum,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
+    }
+
+    pub fn new_no_bias_with_momentum(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        eps: f64,
+        momentum: f64,
+    ) -> Result<Self> {
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
+            weight_and_bias: None,
+            remove_mean: true,
+            eps,
+            momentum,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
+    }
+
     pub fn running_mean(&self) -> &Tensor {
-        &self.running_mean
+        self.running_mean.as_tensor()
     }
 
     pub fn running_var(&self) -> &Tensor {
-        &self.running_var
+        self.running_var.as_tensor()
     }
 
     pub fn eps(&self) -> f64 {
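The switch of the running stats from `Tensor` to `Var` is what makes the in-place updates possible: a `Var` can be overwritten with `set` through a shared reference while still being readable as a plain tensor via `as_tensor`, which is what the getters above now return. A minimal standalone sketch of that pattern (not from the commit, only `Var::from_tensor`, `set`, and `as_tensor` are taken from the diff):

```rust
use candle::{DType, Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Running statistic held behind a Var so it can be mutated through &self.
    let running = Var::from_tensor(&Tensor::zeros(3, DType::F32, &dev)?)?;

    // A new value computed elsewhere (e.g. the EMA update in forward_learning).
    let updated = Tensor::new(&[0.1f32, 0.2, 0.3], &dev)?;
    running.set(&updated)?;

    // Readers such as running_mean() just borrow the underlying tensor.
    println!("{:?}", running.as_tensor().to_vec1::<f32>()?);
    Ok(())
}
```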
@@ -117,7 +190,12 @@ impl BatchNorm {
         self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
     }
 
+    pub fn momentum(&self) -> f64 {
+        self.momentum
+    }
+
     pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+        let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
             DType::F16 | DType::BF16 => DType::F32,
@@ -129,11 +207,11 @@ impl BatchNorm {
                 x.shape()
             )
         }
-        if x.dim(1)? != self.num_features {
+        if x.dim(1)? != num_features {
             candle::bail!(
                 "batch-norm input doesn't have the expected number of features ({:?} <> {})",
                 x.shape(),
-                self.num_features
+                num_features
             )
         }
         let x = x.to_dtype(internal_dtype)?;
@@ -142,26 +220,52 @@ impl BatchNorm {
         let x = x.flatten_from(1)?.contiguous()?;
         let x = if self.remove_mean {
             let mean_x = x.mean_keepdim(1)?;
+            {
+                // Update running mean
+                let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+                    + (mean_x.flatten_all()? * self.momentum)?)?;
+
+                self.running_mean.set(&new_mean)?;
+            }
             x.broadcast_sub(&mean_x)?
         } else {
             x
         };
         let norm_x = x.sqr()?.mean_keepdim(1)?;
+        {
+            // Update running variance
+            let batch_size = x.dim(1)? as f64;
+            let running_var_weight = 1.0 - self.momentum;
+            let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
+
+            let new_var = ((self.running_var.as_tensor() * running_var_weight)?
+                + (&norm_x.flatten_all()? * norm_x_weight)?)?;
+
+            self.running_var.set(&new_var)?;
+        }
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed.to_dtype(x_dtype)?;
         let x = match &self.weight_and_bias {
             None => x,
             Some((weight, bias)) => {
-                let weight = weight.reshape((self.num_features, 1))?;
-                let bias = bias.reshape((self.num_features, 1))?;
+                let weight = weight.reshape(((), 1))?;
+                let bias = bias.reshape(((), 1))?;
                 x.broadcast_mul(&weight)?.broadcast_add(&bias)?
             }
         };
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
+
+    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_learning(x)
+        } else {
+            self.forward(x)
+        }
+    }
 }
 
-impl crate::Module for BatchNorm {
+impl Module for BatchNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
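One detail worth noting in the running-variance branch above: `norm_x` is the biased batch statistic (a plain mean of squared deviations), so it is rescaled by `batch_size / (batch_size - 1.0)` before being folded into the running estimate, keeping the running variance an unbiased-style estimate as is conventional for batch norm. A scalar sketch of one such update, with made-up numbers and not part of the commit:

```rust
// Illustrative only: one running-variance update following forward_learning.
fn main() {
    let momentum = 0.1;
    let batch_size = 16.0; // elements per feature after flattening
    let running_var = 1.0; // current running variance for one channel
    let batch_var_biased = 0.5; // mean of squared deviations over the batch

    // The biased batch statistic is rescaled by n / (n - 1) so the running
    // estimate tracks the unbiased variance.
    let running_var_weight = 1.0 - momentum;
    let norm_x_weight = momentum * batch_size / (batch_size - 1.0);
    let new_var = running_var * running_var_weight + batch_var_biased * norm_x_weight;
    println!("{new_var:.4}"); // prints 0.9533
}
```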
@@ -170,9 +274,13 @@ impl crate::Module for BatchNorm {
             .map(|(idx, v)| if idx == 1 { *v } else { 1 })
             .collect();
         let target_shape = target_shape.as_slice();
+
         let x = x
-            .broadcast_sub(&self.running_mean.reshape(target_shape)?)?
-            .broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?;
+            .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+            .broadcast_div(
+                &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+            )?;
+
         match &self.weight_and_bias {
             None => Ok(x),
             Some((weight, bias)) => {
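The inference path (`Module::forward`) is unchanged apart from the `as_tensor()` calls: the 1-D running stats are reshaped so that only the channel axis (dim 1) keeps its size and are then broadcast against the input. A small sketch of that shape computation for an NCHW input, with illustrative dimensions:

```rust
fn main() {
    // Shape of an NCHW input, e.g. a batch of 2 three-channel 8x8 feature maps.
    let dims = [2usize, 3, 8, 8];
    let target_shape: Vec<usize> = dims
        .iter()
        .enumerate()
        .map(|(idx, v)| if idx == 1 { *v } else { 1 })
        .collect();
    // Running stats of shape [3] broadcast against the input as [1, 3, 1, 1].
    assert_eq!(target_shape, vec![1, 3, 1, 1]);
}
```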
@@ -193,21 +301,21 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
     if config.eps < 0. {
         candle::bail!("batch-norm eps cannot be negative {}", config.eps)
     }
-    let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?;
-    let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?;
+    let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?;
+    let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?;
     let weight_and_bias = if config.affine {
-        let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?;
-        let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?;
+        let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?;
+        let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?;
         Some((weight, bias))
     } else {
         None
     };
     Ok(BatchNorm {
-        running_mean,
-        running_var,
+        running_mean: Var::from_tensor(&running_mean)?,
+        running_var: Var::from_tensor(&running_var)?,
         weight_and_bias,
         remove_mean: config.remove_mean,
         eps: config.eps,
-        num_features,
+        momentum: config.momentum,
     })
 }
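Finally, the `batch_norm` builder keeps its shape: running stats come from the `VarBuilder` (zeros for the mean, ones for the variance) and are wrapped into `Var`s, and an `f64` still converts into a config via `From<f64>`, now picking up the 0.1 momentum default through `..Default::default()`. A usage sketch, assuming `candle_nn`'s usual `VarMap`/`VarBuilder` re-exports; illustrative only:

```rust
use candle::{DType, Device, Result};
use candle_nn::{batch_norm, BatchNormConfig, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // `1e-5` converts via `From<f64>`, so eps is set and momentum falls back
    // to the 0.1 default through `..Default::default()`.
    let _bn_default = batch_norm(16, 1e-5, vb.pp("bn1"))?;

    // Spelling the config out to pick a different momentum.
    let cfg = BatchNormConfig {
        eps: 1e-5,
        remove_mean: true,
        affine: true,
        momentum: 0.01,
    };
    let _bn_slow = batch_norm(16, cfg, vb.pp("bn2"))?;
    Ok(())
}
```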