Small tweaks to batch-norm. (#1505)

This commit is contained in:
Laurent Mazare
2023-12-30 17:06:07 +01:00
committed by GitHub
parent 4290b81244
commit a0facd0e67

View File

@ -7,7 +7,6 @@
//! running stats.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
use crate::Init;
use candle::{DType, Module, Result, Tensor, Var};
#[derive(Debug, Clone, Copy, PartialEq)]
@ -92,7 +91,6 @@ impl BatchNorm {
)
}
}
Ok(())
}
@ -217,34 +215,32 @@ impl BatchNorm {
let x = x.to_dtype(internal_dtype)?;
let x = x.transpose(0, 1)?;
let x_dims_post_transpose = x.dims();
// Flatten all the dimensions exception the channel one as this performs a Spatial Batch
// Normalization.
let x = x.flatten_from(1)?.contiguous()?;
let x = if self.remove_mean {
// The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
let mean_x = x.mean_keepdim(1)?;
{
// Update running mean
let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+ (mean_x.flatten_all()? * self.momentum)?)?;
self.running_mean.set(&new_mean)?;
}
self.running_mean.set(&updated_running_mean)?;
x.broadcast_sub(&mean_x)?
} else {
x
};
// The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
let norm_x = x.sqr()?.mean_keepdim(1)?;
{
// Update running variance
let updated_running_var = {
let batch_size = x.dim(1)? as f64;
let running_var_weight = 1.0 - self.momentum;
let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
let new_var = ((self.running_var.as_tensor() * running_var_weight)?
+ (&norm_x.flatten_all()? * norm_x_weight)?)?;
self.running_var.set(&new_var)?;
}
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed.to_dtype(x_dtype)?;
((self.running_var.as_tensor() * running_var_weight)?
+ (&norm_x.flatten_all()? * norm_x_weight)?)?
};
self.running_var.set(&updated_running_var)?;
let x = x
.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
.to_dtype(x_dtype)?;
let x = match &self.weight_and_bias {
None => x,
Some((weight, bias)) => {
@ -297,6 +293,7 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
config: C,
vb: crate::VarBuilder,
) -> Result<BatchNorm> {
use crate::Init;
let config = config.into();
if config.eps < 0. {
candle::bail!("batch-norm eps cannot be negative {}", config.eps)