Add a batch normalization layer (#508)

* Add BatchNormalization.

* More batch-norm.

* Add some validation of the inputs.

* More validation.
Author: Laurent Mazare
Date: 2023-08-18 20:05:56 +01:00
Committed by: GitHub
Parent: b64e782c2d
Commit: 42e1cc8062

3 changed files with 226 additions and 0 deletions

candle-nn/src/batch_norm.rs (new file)

@@ -0,0 +1,154 @@
//! Batch Normalization.
//!
//! This layer applies Batch Normalization over a mini-batch of inputs as described in [`Batch
//! Normalization`]. The input is expected to have at least three dimensions.
//!
//! Note that this implementation is for inference only: the running statistics are not tracked.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
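//!
//! A minimal usage sketch, assuming a `VarBuilder` `vb` backed by suitable weights and an input
//! `xs` whose second dimension is the feature/channel dimension:
//!
//! ```ignore
//! use candle_nn::{batch_norm, BatchNormConfig, Module};
//!
//! let bn = batch_norm(64, BatchNormConfig::default(), vb)?;
//! let ys = bn.forward(&xs)?;
//! ```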
use candle::{DType, Result, Tensor};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNormConfig {
    pub eps: f64,
    pub remove_mean: bool,
    /// The meaning of affine here is different from LayerNorm: when false there are no learnable
    /// parameters at all, 1 is used for gamma and 0 for beta.
    pub affine: bool,
}

impl Default for BatchNormConfig {
    fn default() -> Self {
        Self {
            eps: 1e-5,
            remove_mean: true,
            affine: true,
        }
    }
}
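
// As a convenience, a bare `f64` converts into a config where it is interpreted as the epsilon
// while the other fields keep their default values, e.g. (hypothetical call, assuming a
// `VarBuilder` named `vb`): `let bn = batch_norm(16, 1e-5, vb)?;`.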
impl From<f64> for BatchNormConfig {
    fn from(eps: f64) -> Self {
        Self {
            eps,
            remove_mean: true,
            affine: true,
        }
    }
}

#[derive(Debug)]
pub struct BatchNorm {
    weight_and_bias: Option<(Tensor, Tensor)>,
    remove_mean: bool,
    eps: f64,
    num_features: usize,
}

impl BatchNorm {
    pub fn new(num_features: usize, weight: Tensor, bias: Tensor, eps: f64) -> Result<Self> {
        if eps < 0. {
            candle::bail!("batch-norm eps cannot be negative {eps}")
        }
        if weight.dims() != [num_features] {
            candle::bail!(
                "batch-norm unexpected weight shape {:?} {num_features}",
                weight.shape()
            )
        }
        if bias.dims() != [num_features] {
            candle::bail!(
                "batch-norm unexpected bias shape {:?} {num_features}",
                bias.shape()
            )
        }
        Ok(Self {
            weight_and_bias: Some((weight, bias)),
            remove_mean: true,
            eps,
            num_features,
        })
    }

    pub fn new_no_bias(num_features: usize, eps: f64) -> Result<Self> {
        if eps < 0. {
            candle::bail!("batch-norm eps cannot be negative {eps}")
        }
        Ok(Self {
            weight_and_bias: None,
            remove_mean: true,
            eps,
            num_features,
        })
    }
}

impl crate::Module for BatchNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x_dtype = x.dtype();
        // Do the computation in f32 for half-precision inputs to avoid precision issues.
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        if x.rank() < 2 {
            candle::bail!(
                "batch-norm input tensor must have at least two dimensions ({:?})",
                x.shape()
            )
        }
        if x.dim(1)? != self.num_features {
            candle::bail!(
                "batch-norm input doesn't have the expected number of features ({:?} <> {})",
                x.shape(),
                self.num_features
            )
        }
        let x = x.to_dtype(internal_dtype)?;
        // Move the feature dimension first and flatten the remaining dimensions so that the
        // statistics below are computed per feature over the whole mini-batch.
        let x = x.transpose(0, 1)?;
        let x_dims_post_transpose = x.dims();
        let x = x.flatten_from(1)?.contiguous()?;
        let x = if self.remove_mean {
            let mean_x = x.mean_keepdim(1)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
        // With the mean removed, norm_x is the biased per-feature variance, so this computes
        // (x - mean) / sqrt(var + eps).
        let norm_x = x.sqr()?.mean_keepdim(1)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed.to_dtype(x_dtype)?;
        // Apply the learnable scale and shift (gamma/beta) when present.
        let x = match &self.weight_and_bias {
            None => x,
            Some((weight, bias)) => {
                let weight = weight.reshape((self.num_features, 1))?;
                let bias = bias.reshape((self.num_features, 1))?;
                x.broadcast_mul(&weight)?.broadcast_add(&bias)?
            }
        };
        x.reshape(x_dims_post_transpose)?.transpose(0, 1)
    }
}

pub fn batch_norm<C: Into<BatchNormConfig>>(
    num_features: usize,
    config: C,
    vb: crate::VarBuilder,
) -> Result<BatchNorm> {
    let config = config.into();
    if config.eps < 0. {
        candle::bail!("batch-norm eps cannot be negative {}", config.eps)
    }
    let weight_and_bias = if config.affine {
        let weight = vb.get_or_init(num_features, "weight", crate::Init::Const(1.))?;
        let bias = vb.get_or_init(num_features, "bias", crate::Init::Const(0.))?;
        Some((weight, bias))
    } else {
        None
    };
    Ok(BatchNorm {
        weight_and_bias,
        remove_mean: config.remove_mean,
        eps: config.eps,
        num_features,
    })
}
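
When the scale and shift come from somewhere other than a `VarBuilder`, `BatchNorm::new` can be
fed the tensors directly. A minimal sketch under assumed values (unit weight and zero bias over
three features, mirroring the `Init::Const` defaults used by `batch_norm` above):

use candle::{Device, Result, Tensor};
use candle_nn::{BatchNorm, Module};

fn batch_norm_by_hand() -> Result<()> {
    let dev = Device::Cpu;
    // Unit gamma and zero beta for three features.
    let weight = Tensor::new(&[1f32, 1., 1.], &dev)?;
    let bias = Tensor::new(&[0f32, 0., 0.], &dev)?;
    let bn = BatchNorm::new(3, weight, bias, 1e-5)?;
    // A (2, 3, 4) input: batch of 2, 3 features, 4 values per feature.
    let xs = Tensor::arange(0f32, 24f32, &dev)?.reshape((2, 3, 4))?;
    let ys = bn.forward(&xs)?;
    assert_eq!(ys.dims(), &[2, 3, 4]);
    Ok(())
}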

candle-nn/src/lib.rs

@@ -1,6 +1,7 @@
 use candle::{Result, Tensor};
 pub mod activation;
+pub mod batch_norm;
 pub mod conv;
 pub mod embedding;
 pub mod group_norm;
@@ -13,6 +14,7 @@ pub mod optim;
 pub mod var_builder;
 pub use activation::Activation;
+pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
 pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
 pub use embedding::{embedding, Embedding};
 pub use group_norm::{group_norm, GroupNorm};

candle-nn/tests/batch_norm.rs (new file)

@@ -0,0 +1,70 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod test_utils;

use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::{BatchNorm, Module};

/* The test below has been generated using the following PyTorch code:
import torch
torch.manual_seed(19551105)
m = torch.nn.BatchNorm2d(5, affine=False)
input = torch.randn(2, 5, 3, 4)
output = m(input)
print(input.flatten())
print(output.flatten())
*/

#[test]
fn batch_norm() -> Result<()> {
    let bn = BatchNorm::new_no_bias(5, 1e-8)?;
    let input: [f32; 120] = [
        -0.7493, -1.0410, 1.6977, -0.6579, 1.7982, -0.0087, 0.2812, -0.1190, 0.2908, -0.5975,
        -0.0278, -0.2138, -1.3130, -1.6048, -2.2028, 0.9452, 0.4002, 0.0831, 1.0004, 0.1860,
        0.5004, 0.5539, 0.9991, -0.2540, -0.0703, -0.3752, -0.1096, -0.2374, 1.0258, -2.2208,
        -0.0257, 0.6073, -1.1627, -0.0964, -1.9718, 1.6577, 0.1931, -0.3692, -0.8011, 0.9059,
        0.4797, 0.6521, -0.0165, -0.6683, -0.4148, 2.0649, -0.8276, 1.7947, -0.2061, 0.5812,
        -1.3598, 1.6192, 1.0466, -0.4423, 0.4202, 0.1749, 0.6969, 0.2616, -0.0369, -1.4951,
        -0.0814, -0.1877, 0.0267, 0.6150, 0.2402, -1.1440, -2.0068, 0.6032, -2.6639, 0.8260,
        0.1085, -0.1693, 1.2805, 0.7654, -0.4930, 0.3770, 1.1309, 0.2303, 0.2949, -0.2634, -0.5225,
        0.4269, 0.6341, 1.5736, 0.9827, -1.2499, 0.3509, -1.6243, -0.8123, 0.7634, -0.3047, 0.0143,
        -0.4032, 0.0537, 0.7022, 0.8405, -1.2221, -1.6847, -0.0714, -0.1608, 0.5579, -1.5858,
        0.4617, -0.6480, 0.1332, 0.0419, -0.9784, 0.4173, 1.2313, -1.9046, -0.1656, 0.1259, 0.0763,
        1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
    ];
    let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
    let output = bn.forward(&input)?;
    assert_eq!(output.dims(), &[2, 5, 3, 4]);
    let output = output.flatten_all()?;
    assert_eq!(
        test_utils::to_vec1_round(&output, 4)?,
        &[
            -0.6391, -0.9414, 1.8965, -0.5444, 2.0007, 0.1283, 0.4287, 0.014, 0.4387, -0.4818,
            0.1085, -0.0842, -1.6809, -2.0057, -2.6714, 0.8328, 0.2262, -0.1268, 0.8943, -0.0123,
            0.3377, 0.3973, 0.8928, -0.5021, 0.0861, -0.2324, 0.0451, -0.0884, 1.2311, -2.1603,
            0.1327, 0.7939, -1.055, 0.0589, -1.9002, 1.8912, 0.2918, -0.3253, -0.7993, 1.0741,
            0.6063, 0.7955, 0.0617, -0.6536, -0.3754, 2.3461, -0.8284, 2.0495, -0.201, 0.6476,
            -1.4446, 1.7665, 1.1493, -0.4556, 0.4741, 0.2097, 0.7723, 0.3031, -0.0186, -1.5905,
            0.053, -0.0572, 0.165, 0.7746, 0.3862, -1.0481, -1.9422, 0.7624, -2.6231, 0.9933,
            0.2498, -0.0381, 1.2061, 0.6327, -0.7681, 0.2004, 1.0396, 0.037, 0.109, -0.5125,
            -0.8009, 0.2559, 0.4865, 1.5324, 1.1861, -1.1461, 0.5261, -1.5372, -0.689, 0.957,
            -0.1587, 0.1745, -0.2616, 0.2156, 0.8931, 1.0375, -1.2614, -1.7691, 0.0015, -0.0966,
            0.6921, -1.6605, 0.5866, -0.6313, 0.226, 0.1258, -0.9939, 0.5378, 1.3484, -2.0319,
            -0.1574, 0.1568, 0.1034, 1.5574, -0.9614, -0.0967, -0.313, -0.7047, -1.5264, 1.0134
        ]
    );
    // Same normalization but with an affine transform: constant weight 0.5 and bias -1.5 for
    // every feature.
    let bn2 = BatchNorm::new(
        5,
        Tensor::new(&[0.5f32], &Device::Cpu)?.broadcast_as(5)?,
        Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
        1e-8,
    )?;
    let output2 = bn2.forward(&input)?;
    assert_eq!(output2.dims(), &[2, 5, 3, 4]);
    let output2 = output2.flatten_all()?;
    // output2 should equal 0.5 * output - 1.5, so the squared differences below should be zero.
    let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
    let sum_diff2 = diff2.sum_keepdim(0)?;
    assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]);
    Ok(())
}