Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Small tweaks to batch-norm. (#1505)
@@ -7,7 +7,6 @@
//! running stats.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
use crate::Init;
use candle::{DType, Module, Result, Tensor, Var};

#[derive(Debug, Clone, Copy, PartialEq)]
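For reference, the module docs above link the original Batch Normalization paper; the per-channel transform it defines (and that the training path edited in the third hunk below implements) is the standard

\hat{x} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}, \qquad y = \gamma \hat{x} + \beta,

with \gamma and \beta corresponding to the optional weight_and_bias pair and \epsilon to config.eps.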
@@ -92,7 +91,6 @@ impl BatchNorm {
)
}
}

Ok(())
}

@@ -217,34 +215,32 @@ impl BatchNorm {
let x = x.to_dtype(internal_dtype)?;
let x = x.transpose(0, 1)?;
let x_dims_post_transpose = x.dims();
// Flatten all the dimensions exception the channel one as this performs a Spatial Batch
// Normalization.
let x = x.flatten_from(1)?.contiguous()?;
let x = if self.remove_mean {
// The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
let mean_x = x.mean_keepdim(1)?;
{
// Update running mean
let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+ (mean_x.flatten_all()? * self.momentum)?)?;

self.running_mean.set(&new_mean)?;
}
let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+ (mean_x.flatten_all()? * self.momentum)?)?;
self.running_mean.set(&updated_running_mean)?;
x.broadcast_sub(&mean_x)?
} else {
x
};
// The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
let norm_x = x.sqr()?.mean_keepdim(1)?;
{
// Update running variance
let updated_running_var = {
let batch_size = x.dim(1)? as f64;
let running_var_weight = 1.0 - self.momentum;
let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);

let new_var = ((self.running_var.as_tensor() * running_var_weight)?
+ (&norm_x.flatten_all()? * norm_x_weight)?)?;

self.running_var.set(&new_var)?;
}
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed.to_dtype(x_dtype)?;
((self.running_var.as_tensor() * running_var_weight)?
+ (&norm_x.flatten_all()? * norm_x_weight)?)?
};
self.running_var.set(&updated_running_var)?;
let x = x
.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
.to_dtype(x_dtype)?;
let x = match &self.weight_and_bias {
None => x,
Some((weight, bias)) => {
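The net effect of the rewrite in this hunk is unchanged: both running statistics remain exponential moving averages, with the batch variance rescaled by batch_size / (batch_size - 1.0) before being blended in. A minimal plain-Rust sketch of that update, assuming nothing from candle (the slice-based helper name and signature below are illustrative only, not the crate's API):

// Sketch of the running-stat update above, on plain f64 slices rather than
// candle tensors. `momentum`, `running_mean` and `running_var` mirror the
// fields used in the diff; everything else is illustrative.
fn update_running_stats(
    batch: &[f64], // one channel's values, flattened over batch/spatial dims
    running_mean: &mut f64,
    running_var: &mut f64,
    momentum: f64,
) {
    let n = batch.len() as f64; // assumes n > 1
    let mean = batch.iter().sum::<f64>() / n;
    // Biased (population) variance of the batch, i.e. the mean of the
    // mean-subtracted squares, as `norm_x` is computed above.
    let biased_var = batch.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;

    // Exponential moving average for the mean:
    // new = (1 - momentum) * old + momentum * batch_mean
    *running_mean = (1.0 - momentum) * *running_mean + momentum * mean;

    // The batch variance is rescaled by n / (n - 1) (Bessel's correction),
    // matching `momentum * batch_size / (batch_size - 1.0)` in the diff.
    *running_var = (1.0 - momentum) * *running_var + momentum * n / (n - 1.0) * biased_var;
}

fn main() {
    let (mut m, mut v) = (0.0, 1.0);
    update_running_stats(&[0.5, 1.5, 2.0, -1.0], &mut m, &mut v, 0.1);
    println!("running mean = {m:.4}, running var = {v:.4}");
}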
@@ -297,6 +293,7 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
config: C,
vb: crate::VarBuilder,
) -> Result<BatchNorm> {
use crate::Init;
let config = config.into();
if config.eps < 0. {
candle::bail!("batch-norm eps cannot be negative {}", config.eps)
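This last hunk shows the constructor pattern: batch_norm accepts anything convertible into a BatchNormConfig and rejects a negative eps up front. A self-contained sketch of that `Into<Config>` plus early-validation pattern, using hypothetical stand-in types rather than candle's (ToyBatchNormConfig, toy_batch_norm and the From<f64> impl below are illustrative only):

// Hypothetical stand-ins, shown only to illustrate the pattern above.
#[derive(Debug, Clone, Copy)]
struct ToyBatchNormConfig {
    eps: f64,
    momentum: f64,
}

impl Default for ToyBatchNormConfig {
    fn default() -> Self {
        Self { eps: 1e-5, momentum: 0.1 }
    }
}

// Accepting `Into<Config>` lets callers pass either a full config or a bare eps.
impl From<f64> for ToyBatchNormConfig {
    fn from(eps: f64) -> Self {
        Self { eps, ..Default::default() }
    }
}

fn toy_batch_norm<C: Into<ToyBatchNormConfig>>(config: C) -> Result<ToyBatchNormConfig, String> {
    let config = config.into();
    // Mirrors the bail! above: eps must not be negative.
    if config.eps < 0. {
        return Err(format!("batch-norm eps cannot be negative {}", config.eps));
    }
    Ok(config)
}

fn main() {
    assert!(toy_batch_norm(1e-5).is_ok());
    assert!(toy_batch_norm(-1.0).is_err());
}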