Small tweaks to batch-norm. (#1505)

Commit: a0facd0e67 (parent 4290b81244)
Author: Laurent Mazare
Date: 2023-12-30 17:06:07 +01:00
Committed by: GitHub


@@ -7,7 +7,6 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use crate::Init;
 use candle::{DType, Module, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
@@ -92,7 +91,6 @@ impl BatchNorm {
                 )
             }
         }
-
         Ok(())
     }
 
@@ -217,34 +215,32 @@ impl BatchNorm {
         let x = x.to_dtype(internal_dtype)?;
         let x = x.transpose(0, 1)?;
         let x_dims_post_transpose = x.dims();
-
+        // Flatten all the dimensions except the channel one as this performs a Spatial Batch
+        // Normalization.
         let x = x.flatten_from(1)?.contiguous()?;
         let x = if self.remove_mean {
+            // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
             let mean_x = x.mean_keepdim(1)?;
-            {
-                // Update running mean
-                let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
-                    + (mean_x.flatten_all()? * self.momentum)?)?;
-                self.running_mean.set(&new_mean)?;
-            }
+            let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+                + (mean_x.flatten_all()? * self.momentum)?)?;
+            self.running_mean.set(&updated_running_mean)?;
             x.broadcast_sub(&mean_x)?
         } else {
             x
         };
+        // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
         let norm_x = x.sqr()?.mean_keepdim(1)?;
-        {
-            // Update running variance
+        let updated_running_var = {
             let batch_size = x.dim(1)? as f64;
             let running_var_weight = 1.0 - self.momentum;
             let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
-
-            let new_var = ((self.running_var.as_tensor() * running_var_weight)?
-                + (&norm_x.flatten_all()? * norm_x_weight)?)?;
-
-            self.running_var.set(&new_var)?;
-        }
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed.to_dtype(x_dtype)?;
+            ((self.running_var.as_tensor() * running_var_weight)?
+                + (&norm_x.flatten_all()? * norm_x_weight)?)?
+        };
+        self.running_var.set(&updated_running_var)?;
+        let x = x
+            .broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)?;
         let x = match &self.weight_and_bias {
             None => x,
             Some((weight, bias)) => {
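
The refactor above keeps the update rule itself unchanged: both running stats are exponential moving averages weighted by self.momentum, with the batch variance rescaled by batch_size / (batch_size - 1.0) to turn the biased batch statistic into an unbiased estimate. A minimal standalone sketch of that arithmetic on plain f64 values (the helper name update_running and the sample numbers are illustrative, not part of the commit):

// Exponential moving average as used for both running stats above:
// new = old * (1 - momentum) + batch_stat * momentum
fn update_running(running: f64, batch_stat: f64, momentum: f64) -> f64 {
    running * (1.0 - momentum) + batch_stat * momentum
}

fn main() {
    let momentum = 0.1;
    let batch_size = 16.0;

    // Running mean: the batch mean feeds in directly.
    let running_mean = update_running(0.0, 0.5, momentum);
    println!("running mean: {running_mean}"); // 0.05

    // Running variance: the biased batch variance is first rescaled by
    // n / (n - 1), matching `norm_x_weight` in the diff.
    let batch_var_unbiased = 0.9 * batch_size / (batch_size - 1.0);
    let running_var = update_running(1.0, batch_var_unbiased, momentum);
    println!("running var: {running_var:.4}"); // ~0.9960
}
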
@@ -297,6 +293,7 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
     config: C,
     vb: crate::VarBuilder,
 ) -> Result<BatchNorm> {
+    use crate::Init;
     let config = config.into();
     if config.eps < 0. {
         candle::bail!("batch-norm eps cannot be negative {}", config.eps)
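
The last hunk simply moves the use crate::Init; import into batch_norm, its only user; the constructor still rejects a negative eps. A rough usage sketch of this constructor, assuming candle-nn's usual VarMap/VarBuilder setup and a forward_train training-mode method (both as in current candle-nn; neither is shown in this diff):

use candle::{DType, Device, Tensor};
use candle_nn::{batch_norm, BatchNormConfig, VarBuilder, VarMap};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // The default config uses a small positive eps, so the negative-eps
    // bail in the constructor above is not hit.
    let bn = batch_norm(3, BatchNormConfig::default(), vb)?;

    // NCHW input: batch of 4, 3 channels, 8x8 spatial.
    let x = Tensor::randn(0f32, 1f32, (4, 3, 8, 8), &dev)?;
    // Training-mode forward runs the running-stat updates from the hunk above.
    let y = bn.forward_train(&x)?;
    println!("{:?}", y.dims()); // [4, 3, 8, 8]
    Ok(())
}
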