diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs new file mode 100644 index 00000000..d1aa2148 --- /dev/null +++ b/candle-nn/src/batch_norm.rs @@ -0,0 +1,154 @@ +//! Batch Normalization. +//! +//! This layer applies Batch Normalization over a mini-batch of inputs as described in [`Batch +//! Normalization`]. The input is expected to have at least three dimensions. +//! +//! Note that this implementation is for inference only, there is no possibility to track the +//! running stats. +//! +//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167 +use candle::{DType, Result, Tensor}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct BatchNormConfig { + pub eps: f64, + pub remove_mean: bool, + /// The meaning of affine here is different from LayerNorm: when false there is no learnable + /// parameter at all, 1 used for gamma and 0 for beta. + pub affine: bool, +} + +impl Default for BatchNormConfig { + fn default() -> Self { + Self { + eps: 1e-5, + remove_mean: true, + affine: true, + } + } +} + +impl From for BatchNormConfig { + fn from(eps: f64) -> Self { + Self { + eps, + remove_mean: true, + affine: true, + } + } +} + +#[derive(Debug)] +pub struct BatchNorm { + weight_and_bias: Option<(Tensor, Tensor)>, + remove_mean: bool, + eps: f64, + num_features: usize, +} + +impl BatchNorm { + pub fn new(num_features: usize, weight: Tensor, bias: Tensor, eps: f64) -> Result { + if eps < 0. { + candle::bail!("batch-norm eps cannot be negative {eps}") + } + if weight.dims() != [num_features] { + candle::bail!( + "batch-norm unexpected weight shape {:?} {num_features}", + weight.shape() + ) + } + if bias.dims() != [num_features] { + candle::bail!( + "batch-norm unexpected bias shape {:?} {num_features}", + bias.shape() + ) + } + Ok(Self { + weight_and_bias: Some((weight, bias)), + remove_mean: true, + eps, + num_features, + }) + } + + pub fn new_no_bias(num_features: usize, eps: f64) -> Result { + if eps < 0. { + candle::bail!("batch-norm eps cannot be negative {eps}") + } + Ok(Self { + weight_and_bias: None, + remove_mean: true, + eps, + num_features, + }) + } +} + +impl crate::Module for BatchNorm { + fn forward(&self, x: &Tensor) -> Result { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + if x.rank() < 2 { + candle::bail!( + "batch-norm input tensor must have at least two dimensions ({:?})", + x.shape() + ) + } + if x.dim(1)? != self.num_features { + candle::bail!( + "batch-norm input doesn't have the expected number of features ({:?} <> {})", + x.shape(), + self.num_features + ) + } + let x = x.to_dtype(internal_dtype)?; + let x = x.transpose(0, 1)?; + let x_dims_post_transpose = x.dims(); + let x = x.flatten_from(1)?.contiguous()?; + let x = if self.remove_mean { + let mean_x = x.mean_keepdim(1)?; + x.broadcast_sub(&mean_x)? + } else { + x + }; + let norm_x = x.sqr()?.mean_keepdim(1)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let x = x_normed.to_dtype(x_dtype)?; + let x = match &self.weight_and_bias { + None => x, + Some((weight, bias)) => { + let weight = weight.reshape((self.num_features, 1))?; + let bias = bias.reshape((self.num_features, 1))?; + x.broadcast_mul(&weight)?.broadcast_add(&bias)? + } + }; + x.reshape(x_dims_post_transpose)?.transpose(0, 1) + } +} + +pub fn batch_norm>( + num_features: usize, + config: C, + vb: crate::VarBuilder, +) -> Result { + let config = config.into(); + if config.eps < 0. { + candle::bail!("batch-norm eps cannot be negative {}", config.eps) + } + let weight_and_bias = if config.affine { + let weight = vb.get_or_init(num_features, "weight", crate::Init::Const(1.))?; + let bias = vb.get_or_init(num_features, "bias", crate::Init::Const(0.))?; + Some((weight, bias)) + } else { + None + }; + Ok(BatchNorm { + weight_and_bias, + remove_mean: config.remove_mean, + eps: config.eps, + num_features, + }) +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index da63d592..7b486f71 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -1,6 +1,7 @@ use candle::{Result, Tensor}; pub mod activation; +pub mod batch_norm; pub mod conv; pub mod embedding; pub mod group_norm; @@ -13,6 +14,7 @@ pub mod optim; pub mod var_builder; pub use activation::Activation; +pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig}; pub use embedding::{embedding, Embedding}; pub use group_norm::{group_norm, GroupNorm}; diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs new file mode 100644 index 00000000..1575e914 --- /dev/null +++ b/candle-nn/tests/batch_norm.rs @@ -0,0 +1,70 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +mod test_utils; + +use anyhow::Result; +use candle::{Device, Tensor}; +use candle_nn::{BatchNorm, Module}; + +/* The test below has been generated using the following PyTorch code: +import torch +torch.manual_seed(19551105) +m = torch.nn.BatchNorm2d(5, affine=False) +input = torch.randn(2, 5, 3, 4) +output = m(input) +print(input.flatten()) +print(output.flatten()) +*/ +#[test] +fn batch_norm() -> Result<()> { + let bn = BatchNorm::new_no_bias(5, 1e-8)?; + let input: [f32; 120] = [ + -0.7493, -1.0410, 1.6977, -0.6579, 1.7982, -0.0087, 0.2812, -0.1190, 0.2908, -0.5975, + -0.0278, -0.2138, -1.3130, -1.6048, -2.2028, 0.9452, 0.4002, 0.0831, 1.0004, 0.1860, + 0.5004, 0.5539, 0.9991, -0.2540, -0.0703, -0.3752, -0.1096, -0.2374, 1.0258, -2.2208, + -0.0257, 0.6073, -1.1627, -0.0964, -1.9718, 1.6577, 0.1931, -0.3692, -0.8011, 0.9059, + 0.4797, 0.6521, -0.0165, -0.6683, -0.4148, 2.0649, -0.8276, 1.7947, -0.2061, 0.5812, + -1.3598, 1.6192, 1.0466, -0.4423, 0.4202, 0.1749, 0.6969, 0.2616, -0.0369, -1.4951, + -0.0814, -0.1877, 0.0267, 0.6150, 0.2402, -1.1440, -2.0068, 0.6032, -2.6639, 0.8260, + 0.1085, -0.1693, 1.2805, 0.7654, -0.4930, 0.3770, 1.1309, 0.2303, 0.2949, -0.2634, -0.5225, + 0.4269, 0.6341, 1.5736, 0.9827, -1.2499, 0.3509, -1.6243, -0.8123, 0.7634, -0.3047, 0.0143, + -0.4032, 0.0537, 0.7022, 0.8405, -1.2221, -1.6847, -0.0714, -0.1608, 0.5579, -1.5858, + 0.4617, -0.6480, 0.1332, 0.0419, -0.9784, 0.4173, 1.2313, -1.9046, -0.1656, 0.1259, 0.0763, + 1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205, + ]; + let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?; + let output = bn.forward(&input)?; + assert_eq!(output.dims(), &[2, 5, 3, 4]); + let output = output.flatten_all()?; + assert_eq!( + test_utils::to_vec1_round(&output, 4)?, + &[ + -0.6391, -0.9414, 1.8965, -0.5444, 2.0007, 0.1283, 0.4287, 0.014, 0.4387, -0.4818, + 0.1085, -0.0842, -1.6809, -2.0057, -2.6714, 0.8328, 0.2262, -0.1268, 0.8943, -0.0123, + 0.3377, 0.3973, 0.8928, -0.5021, 0.0861, -0.2324, 0.0451, -0.0884, 1.2311, -2.1603, + 0.1327, 0.7939, -1.055, 0.0589, -1.9002, 1.8912, 0.2918, -0.3253, -0.7993, 1.0741, + 0.6063, 0.7955, 0.0617, -0.6536, -0.3754, 2.3461, -0.8284, 2.0495, -0.201, 0.6476, + -1.4446, 1.7665, 1.1493, -0.4556, 0.4741, 0.2097, 0.7723, 0.3031, -0.0186, -1.5905, + 0.053, -0.0572, 0.165, 0.7746, 0.3862, -1.0481, -1.9422, 0.7624, -2.6231, 0.9933, + 0.2498, -0.0381, 1.2061, 0.6327, -0.7681, 0.2004, 1.0396, 0.037, 0.109, -0.5125, + -0.8009, 0.2559, 0.4865, 1.5324, 1.1861, -1.1461, 0.5261, -1.5372, -0.689, 0.957, + -0.1587, 0.1745, -0.2616, 0.2156, 0.8931, 1.0375, -1.2614, -1.7691, 0.0015, -0.0966, + 0.6921, -1.6605, 0.5866, -0.6313, 0.226, 0.1258, -0.9939, 0.5378, 1.3484, -2.0319, + -0.1574, 0.1568, 0.1034, 1.5574, -0.9614, -0.0967, -0.313, -0.7047, -1.5264, 1.0134 + ] + ); + let bn2 = BatchNorm::new( + 5, + Tensor::new(&[0.5f32], &Device::Cpu)?.broadcast_as(5)?, + Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?, + 1e-8, + )?; + let output2 = bn2.forward(&input)?; + assert_eq!(output2.dims(), &[2, 5, 3, 4]); + let output2 = output2.flatten_all()?; + let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?; + let sum_diff2 = diff2.sum_keepdim(0)?; + assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]); + Ok(()) +}