Mirror of https://github.com/huggingface/candle.git
Small tweaks to batch-norm. (#1505)
All changes are in candle-nn/src/batch_norm.rs. The first hunk drops the module-level `use crate::Init;` import; the last hunk re-introduces it inside the `batch_norm` function, the only place it is used:

@@ -7,7 +7,6 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use crate::Init;
 use candle::{DType, Module, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
The second hunk is whitespace-only, removing a stray blank line:

@@ -92,7 +91,6 @@ impl BatchNorm {
                 )
             }
         }
-
         Ok(())
     }
 
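The main hunk below touches the training-time forward pass. Before any statistics are computed, the input is transposed so that channels come first and all remaining dimensions are flattened together, which is what makes this a spatial batch normalization: dim 1 then ranges over every batch element and spatial position at once. A minimal shape walk-through, with made-up NCHW sizes (not part of the diff):

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        // Illustrative NCHW input: batch 8, 3 channels, 4x4 spatial.
        let x = Tensor::zeros((8, 3, 4, 4), DType::F32, &Device::Cpu)?;
        let x = x.transpose(0, 1)?; // (3, 8, 4, 4): channel dim first
        let x = x.flatten_from(1)?; // (3, 128): batch and spatial dims merged
        assert_eq!(x.dims(), &[3, 128]);
        // mean_keepdim(1) now averages over all N*H*W values per channel.
        Ok(())
    }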
The scoped `{ ... }` blocks around the running-stat updates are flattened into plain `let` bindings (`new_mean` and `new_var` become `updated_running_mean` and `updated_running_var`), the `x_normed` temporary is folded into a single method chain, and explanatory comments are added:

@@ -217,34 +215,32 @@ impl BatchNorm {
         let x = x.to_dtype(internal_dtype)?;
         let x = x.transpose(0, 1)?;
         let x_dims_post_transpose = x.dims();
+        // Flatten all the dimensions exception the channel one as this performs a Spatial Batch
+        // Normalization.
         let x = x.flatten_from(1)?.contiguous()?;
         let x = if self.remove_mean {
+            // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
             let mean_x = x.mean_keepdim(1)?;
-            {
-                // Update running mean
-                let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
-                    + (mean_x.flatten_all()? * self.momentum)?)?;
-
-                self.running_mean.set(&new_mean)?;
-            }
+            let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+                + (mean_x.flatten_all()? * self.momentum)?)?;
+            self.running_mean.set(&updated_running_mean)?;
             x.broadcast_sub(&mean_x)?
         } else {
             x
         };
+        // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
         let norm_x = x.sqr()?.mean_keepdim(1)?;
-        {
-            // Update running variance
+        let updated_running_var = {
             let batch_size = x.dim(1)? as f64;
             let running_var_weight = 1.0 - self.momentum;
             let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
-            let new_var = ((self.running_var.as_tensor() * running_var_weight)?
-                + (&norm_x.flatten_all()? * norm_x_weight)?)?;
-            self.running_var.set(&new_var)?;
-        }
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed.to_dtype(x_dtype)?;
+            ((self.running_var.as_tensor() * running_var_weight)?
+                + (&norm_x.flatten_all()? * norm_x_weight)?)?
+        };
+        self.running_var.set(&updated_running_var)?;
+        let x = x
+            .broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)?;
         let x = match &self.weight_and_bias {
             None => x,
             Some((weight, bias)) => {
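The refactor leaves the arithmetic untouched. Both running statistics follow the usual exponential-moving-average rule, running <- (1 - momentum) * running + momentum * batch_stat, and the variance weight carries an extra factor of batch_size / (batch_size - 1): the mean over dim 1 gives the biased batch variance, and this Bessel correction turns it into the unbiased estimate stored in running_var. A standalone sketch of the same update on plain tensors (the helper and its signature are illustrative, not part of candle-nn):

    use candle::{Result, Tensor};

    // Mirrors the running-stat updates in the hunk above.
    // `batch_var` is the biased batch variance, i.e. the mean of squared deviations.
    fn update_running_stats(
        running_mean: &Tensor,
        running_var: &Tensor,
        batch_mean: &Tensor,
        batch_var: &Tensor,
        batch_size: f64,
        momentum: f64,
    ) -> Result<(Tensor, Tensor)> {
        let new_mean = ((running_mean * (1.0 - momentum))? + (batch_mean * momentum)?)?;
        // Bessel's correction folded into the EMA weight: n / (n - 1).
        let var_weight = momentum * batch_size / (batch_size - 1.0);
        let new_var = ((running_var * (1.0 - momentum))? + (batch_var * var_weight)?)?;
        Ok((new_mean, new_var))
    }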
The final hunk re-adds the import at its point of use, inside `batch_norm`:

@@ -297,6 +293,7 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
     config: C,
     vb: crate::VarBuilder,
 ) -> Result<BatchNorm> {
+    use crate::Init;
     let config = config.into();
     if config.eps < 0. {
         candle::bail!("batch-norm eps cannot be negative {}", config.eps)
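For reference, a usage sketch of the constructor validated above. It assumes current candle-nn conventions: a VarMap/VarBuilder pair supplying the variables, the From<f64> conversion that builds a BatchNormConfig from an epsilon value, and forward_train as the training-mode entry point (earlier releases exposed this as forward_learning):

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{batch_norm, BatchNorm, VarBuilder, VarMap};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        // 16 channels; the bare 1e-5 goes through Into<BatchNormConfig> as eps.
        let bn: BatchNorm = batch_norm(16, 1e-5, vb)?;
        // Training-mode forward on an illustrative NCHW batch; this is the
        // path that updates running_mean/running_var as in the diff above.
        let xs = Tensor::randn(0f32, 1f32, (8, 16, 32, 32), &device)?;
        let ys = bn.forward_train(&xs)?;
        assert_eq!(ys.dims(), xs.dims());
        Ok(())
    }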