Mirror of https://github.com/huggingface/candle.git, synced 2025-06-20 04:00:28 +00:00
[Breaking] Add training to batchnorm with exponential moving average (#1504)
* Add training to batchnorm with exponential moving average
* Add more checks to batch norm
* Resolve some review comments
* Add with_momentum variants of `new` methods
* Add check for range of momentum variable; update batch norm test
* Run cargo fmt
* Add back num_features parameter
* Format; tiny simplification
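As a rough usage sketch of the surface added by this commit (not part of the patch below): the tensor shapes, values, and use of Device::Cpu are illustrative assumptions, while `new_no_bias_with_momentum`, `forward_t`, and `running_mean()` come from the diff itself.

// Hedged usage sketch: exercise the momentum-aware batch-norm API added here.
use candle::{DType, Device, Result, Tensor};
use candle_nn::BatchNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let num_features = 5;
    let running_mean = Tensor::zeros(num_features, DType::F32, &dev)?;
    let running_var = Tensor::ones(num_features, DType::F32, &dev)?;
    // momentum must lie in [0, 1]; check_validity() rejects anything else.
    let bn = BatchNorm::new_no_bias_with_momentum(num_features, running_mean, running_var, 1e-5, 0.1)?;

    let x = Tensor::randn(0f32, 1f32, (2, num_features, 3, 4), &dev)?;
    // train = true routes through forward_learning and updates the running stats in place;
    // train = false normalizes with the stored running stats (the plain Module::forward path).
    let y = bn.forward_t(&x, true)?;
    println!("{:?} {:?}", y.shape(), bn.running_mean());
    Ok(())
}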
@@ -7,15 +7,22 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Result, Tensor};
+use crate::Init;
+use candle::{DType, Module, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
     pub eps: f64,
     pub remove_mean: bool,
+
     /// The meaning of affine here is different from LayerNorm: when false there is no learnable
     /// parameter at all, 1 used for gamma and 0 for beta.
     pub affine: bool,
+
+    /// Controls exponential moving average of running stats. Defaults to 0.1
+    ///
+    /// `running_stat * (1.0 - momentum) + stat * momentum`.
+    pub momentum: f64,
 }
 
 impl Default for BatchNormConfig {
@@ -24,6 +31,7 @@ impl Default for BatchNormConfig {
             eps: 1e-5,
             remove_mean: true,
             affine: true,
+            momentum: 0.1,
         }
     }
 }
@@ -32,23 +40,62 @@ impl From<f64> for BatchNormConfig {
     fn from(eps: f64) -> Self {
         Self {
             eps,
-            remove_mean: true,
-            affine: true,
+            ..Default::default()
         }
     }
 }
 
 #[derive(Clone, Debug)]
 pub struct BatchNorm {
-    running_mean: Tensor,
-    running_var: Tensor,
+    running_mean: Var,
+    running_var: Var,
     weight_and_bias: Option<(Tensor, Tensor)>,
     remove_mean: bool,
     eps: f64,
-    num_features: usize,
+    momentum: f64,
 }
 
 impl BatchNorm {
+    fn check_validity(&self, num_features: usize) -> Result<()> {
+        if self.eps < 0. {
+            candle::bail!("batch-norm eps cannot be negative {}", self.eps)
+        }
+        if !(0.0..=1.0).contains(&self.momentum) {
+            candle::bail!(
+                "batch-norm momentum must be between 0 and 1, is {}",
+                self.momentum
+            )
+        }
+        if self.running_mean.dims() != [num_features] {
+            candle::bail!(
+                "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]",
+                self.running_mean.shape(),
+            )
+        }
+        if self.running_var.dims() != [num_features] {
+            candle::bail!(
+                "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]",
+                self.running_var.shape(),
+            )
+        }
+        if let Some((ref weight, ref bias)) = self.weight_and_bias.as_ref() {
+            if weight.dims() != [num_features] {
+                candle::bail!(
+                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+                    weight.shape(),
+                )
+            }
+            if bias.dims() != [num_features] {
+                candle::bail!(
+                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+                    bias.shape(),
+                )
+            }
+        }
+
+        Ok(())
+    }
+
     pub fn new(
         num_features: usize,
         running_mean: Tensor,
@@ -57,29 +104,16 @@ impl BatchNorm {
         bias: Tensor,
         eps: f64,
     ) -> Result<Self> {
-        if eps < 0. {
-            candle::bail!("batch-norm eps cannot be negative {eps}")
-        }
-        if weight.dims() != [num_features] {
-            candle::bail!(
-                "batch-norm unexpected weight shape {:?} {num_features}",
-                weight.shape()
-            )
-        }
-        if bias.dims() != [num_features] {
-            candle::bail!(
-                "batch-norm unexpected bias shape {:?} {num_features}",
-                bias.shape()
-            )
-        }
-        Ok(Self {
-            running_mean,
-            running_var,
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
             weight_and_bias: Some((weight, bias)),
             remove_mean: true,
             eps,
-            num_features,
-        })
+            momentum: 0.1,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
     }
 
     pub fn new_no_bias(
@@ -88,25 +122,64 @@ impl BatchNorm {
         running_var: Tensor,
         eps: f64,
     ) -> Result<Self> {
-        if eps < 0. {
-            candle::bail!("batch-norm eps cannot be negative {eps}")
-        }
-        Ok(Self {
-            running_mean,
-            running_var,
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
             weight_and_bias: None,
             remove_mean: true,
             eps,
-            num_features,
-        })
+            momentum: 0.1,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
+    }
+
+    pub fn new_with_momentum(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        eps: f64,
+        momentum: f64,
+    ) -> Result<Self> {
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
+            weight_and_bias: Some((weight, bias)),
+            remove_mean: true,
+            eps,
+            momentum,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
+    }
+
+    pub fn new_no_bias_with_momentum(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        eps: f64,
+        momentum: f64,
+    ) -> Result<Self> {
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
+            weight_and_bias: None,
+            remove_mean: true,
+            eps,
+            momentum,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
     }
 
     pub fn running_mean(&self) -> &Tensor {
-        &self.running_mean
+        self.running_mean.as_tensor()
     }
 
     pub fn running_var(&self) -> &Tensor {
-        &self.running_var
+        self.running_var.as_tensor()
     }
 
     pub fn eps(&self) -> f64 {
@@ -117,7 +190,12 @@ impl BatchNorm {
         self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
     }
 
+    pub fn momentum(&self) -> f64 {
+        self.momentum
+    }
+
     pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+        let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
             DType::F16 | DType::BF16 => DType::F32,
@@ -129,11 +207,11 @@ impl BatchNorm {
                 x.shape()
             )
         }
-        if x.dim(1)? != self.num_features {
+        if x.dim(1)? != num_features {
             candle::bail!(
                 "batch-norm input doesn't have the expected number of features ({:?} <> {})",
                 x.shape(),
-                self.num_features
+                num_features
             )
         }
         let x = x.to_dtype(internal_dtype)?;
@@ -142,26 +220,52 @@ impl BatchNorm {
         let x = x.flatten_from(1)?.contiguous()?;
         let x = if self.remove_mean {
             let mean_x = x.mean_keepdim(1)?;
+            {
+                // Update running mean
+                let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+                    + (mean_x.flatten_all()? * self.momentum)?)?;
+
+                self.running_mean.set(&new_mean)?;
+            }
             x.broadcast_sub(&mean_x)?
         } else {
             x
         };
         let norm_x = x.sqr()?.mean_keepdim(1)?;
+        {
+            // Update running variance
+            let batch_size = x.dim(1)? as f64;
+            let running_var_weight = 1.0 - self.momentum;
+            let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
+
+            let new_var = ((self.running_var.as_tensor() * running_var_weight)?
+                + (&norm_x.flatten_all()? * norm_x_weight)?)?;
+
+            self.running_var.set(&new_var)?;
+        }
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed.to_dtype(x_dtype)?;
         let x = match &self.weight_and_bias {
             None => x,
             Some((weight, bias)) => {
-                let weight = weight.reshape((self.num_features, 1))?;
-                let bias = bias.reshape((self.num_features, 1))?;
+                let weight = weight.reshape(((), 1))?;
+                let bias = bias.reshape(((), 1))?;
                 x.broadcast_mul(&weight)?.broadcast_add(&bias)?
             }
         };
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
+
+    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_learning(x)
+        } else {
+            self.forward(x)
+        }
+    }
 }
 
-impl crate::Module for BatchNorm {
+impl Module for BatchNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
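A small standalone sketch of the update rule in the hunk above (illustrative numbers only, not part of the patch): the per-batch variance from `mean_keepdim` is the biased population estimate, so the running-variance update rescales it by n / (n - 1) before blending, while the running mean is blended as-is, following the `running_stat * (1.0 - momentum) + stat * momentum` rule documented on `momentum`.

// Scalar illustration of the exponential-moving-average update used for the running stats.
// `n` is x.dim(1) after the flatten, i.e. elements per feature; all values here are made up.
fn ema(running: f64, batch_stat: f64, momentum: f64) -> f64 {
    running * (1.0 - momentum) + batch_stat * momentum
}

fn main() {
    let momentum = 0.1;
    let n = 24.0; // e.g. a (2, 5, 3, 4) input gives 2 * 3 * 4 = 24 elements per feature
    let new_mean = ema(0.0, -0.05, momentum);
    // Bessel-style correction: the biased batch variance is scaled by n / (n - 1).
    let new_var = ema(1.0, 0.9 * n / (n - 1.0), momentum);
    println!("running_mean -> {new_mean:.4}, running_var -> {new_var:.4}");
}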
@@ -170,9 +274,13 @@ impl crate::Module for BatchNorm {
             .map(|(idx, v)| if idx == 1 { *v } else { 1 })
             .collect();
         let target_shape = target_shape.as_slice();
+
         let x = x
-            .broadcast_sub(&self.running_mean.reshape(target_shape)?)?
-            .broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?;
+            .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+            .broadcast_div(
+                &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+            )?;
+
         match &self.weight_and_bias {
             None => Ok(x),
             Some((weight, bias)) => {
@@ -193,21 +301,21 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
     if config.eps < 0. {
         candle::bail!("batch-norm eps cannot be negative {}", config.eps)
     }
-    let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?;
-    let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?;
+    let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?;
+    let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?;
     let weight_and_bias = if config.affine {
-        let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?;
-        let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?;
+        let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?;
+        let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?;
         Some((weight, bias))
     } else {
         None
     };
     Ok(BatchNorm {
-        running_mean,
-        running_var,
+        running_mean: Var::from_tensor(&running_mean)?,
+        running_var: Var::from_tensor(&running_var)?,
         weight_and_bias,
         remove_mean: config.remove_mean,
         eps: config.eps,
-        num_features,
+        momentum: config.momentum,
     })
 }
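The `batch_norm` builder above now threads `config.momentum` through as well; a hedged sketch of that path follows (the `VarMap`/`VarBuilder` plumbing is standard candle-nn usage assumed here, not shown in the diff), before the remaining hunks below extend the batch-norm test.

// Hedged sketch: building a BatchNorm through the VarBuilder path with a non-default momentum.
use candle::{DType, Device, Result};
use candle_nn::{batch_norm, BatchNormConfig, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let cfg = BatchNormConfig {
        eps: 1e-5,
        remove_mean: true,
        affine: true,
        momentum: 0.05, // overrides the 0.1 default introduced by this commit
    };
    let bn = batch_norm(5, cfg, vb.pp("bn"))?;
    println!("momentum = {}", bn.momentum());
    Ok(())
}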
@@ -16,6 +16,8 @@ input = torch.randn(2, 5, 3, 4)
 output = m(input)
 print(input.flatten())
 print(output.flatten())
+print(m.running_mean)
+print(m.running_var)
 */
 #[test]
 fn batch_norm() -> Result<()> {
@@ -71,5 +73,14 @@ fn batch_norm() -> Result<()> {
     let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
     let sum_diff2 = diff2.sum_keepdim(0)?;
     assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]);
+
+    assert_eq!(
+        test_utils::to_vec1_round(bn.running_mean(), 4)?,
+        &[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(bn.running_var(), 4)?,
+        &[0.9972, 0.9842, 0.9956, 0.9866, 0.9898]
+    );
     Ok(())
 }