Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 19:58:35 +00:00)
Add a batch normalization layer (#508)
* Add BatchNormalization.
* More batch-norm.
* Add some validation of the inputs.
* More validation.
candle-nn/src/batch_norm.rs (new file, 154 lines)
@@ -0,0 +1,154 @@
//! Batch Normalization.
//!
//! This layer applies Batch Normalization over a mini-batch of inputs as described in [`Batch
//! Normalization`]. The input is expected to have at least three dimensions.
//!
//! Note that this implementation is for inference only, there is no possibility to track the
//! running stats.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
use candle::{DType, Result, Tensor};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNormConfig {
    pub eps: f64,
    pub remove_mean: bool,
    /// The meaning of affine here is different from LayerNorm: when false there is no learnable
    /// parameter at all, 1 used for gamma and 0 for beta.
    pub affine: bool,
}

impl Default for BatchNormConfig {
    fn default() -> Self {
        Self {
            eps: 1e-5,
            remove_mean: true,
            affine: true,
        }
    }
}

impl From<f64> for BatchNormConfig {
    fn from(eps: f64) -> Self {
        Self {
            eps,
            remove_mean: true,
            affine: true,
        }
    }
}

#[derive(Debug)]
pub struct BatchNorm {
    weight_and_bias: Option<(Tensor, Tensor)>,
    remove_mean: bool,
    eps: f64,
    num_features: usize,
}

impl BatchNorm {
    pub fn new(num_features: usize, weight: Tensor, bias: Tensor, eps: f64) -> Result<Self> {
        if eps < 0. {
            candle::bail!("batch-norm eps cannot be negative {eps}")
        }
        if weight.dims() != [num_features] {
            candle::bail!(
                "batch-norm unexpected weight shape {:?} {num_features}",
                weight.shape()
            )
        }
        if bias.dims() != [num_features] {
            candle::bail!(
                "batch-norm unexpected bias shape {:?} {num_features}",
                bias.shape()
            )
        }
        Ok(Self {
            weight_and_bias: Some((weight, bias)),
            remove_mean: true,
            eps,
            num_features,
        })
    }

    pub fn new_no_bias(num_features: usize, eps: f64) -> Result<Self> {
        if eps < 0. {
            candle::bail!("batch-norm eps cannot be negative {eps}")
        }
        Ok(Self {
            weight_and_bias: None,
            remove_mean: true,
            eps,
            num_features,
        })
    }
}

impl crate::Module for BatchNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        if x.rank() < 2 {
            candle::bail!(
                "batch-norm input tensor must have at least two dimensions ({:?})",
                x.shape()
            )
        }
        if x.dim(1)? != self.num_features {
            candle::bail!(
                "batch-norm input doesn't have the expected number of features ({:?} <> {})",
                x.shape(),
                self.num_features
            )
        }
        let x = x.to_dtype(internal_dtype)?;
        let x = x.transpose(0, 1)?;
        let x_dims_post_transpose = x.dims();
        let x = x.flatten_from(1)?.contiguous()?;
        let x = if self.remove_mean {
            let mean_x = x.mean_keepdim(1)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
        let norm_x = x.sqr()?.mean_keepdim(1)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed.to_dtype(x_dtype)?;
        let x = match &self.weight_and_bias {
            None => x,
            Some((weight, bias)) => {
                let weight = weight.reshape((self.num_features, 1))?;
                let bias = bias.reshape((self.num_features, 1))?;
                x.broadcast_mul(&weight)?.broadcast_add(&bias)?
            }
        };
        x.reshape(x_dims_post_transpose)?.transpose(0, 1)
    }
}

pub fn batch_norm<C: Into<BatchNormConfig>>(
    num_features: usize,
    config: C,
    vb: crate::VarBuilder,
) -> Result<BatchNorm> {
    let config = config.into();
    if config.eps < 0. {
        candle::bail!("batch-norm eps cannot be negative {}", config.eps)
    }
    let weight_and_bias = if config.affine {
        let weight = vb.get_or_init(num_features, "weight", crate::Init::Const(1.))?;
        let bias = vb.get_or_init(num_features, "bias", crate::Init::Const(0.))?;
        Some((weight, bias))
    } else {
        None
    };
    Ok(BatchNorm {
        weight_and_bias,
        remove_mean: config.remove_mean,
        eps: config.eps,
        num_features,
    })
}
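For context, a minimal usage sketch (not part of the commit), assuming the candle and candle-nn crates at this revision and that Tensor::randn is available for the dummy input; the feature count and shapes are made up for illustration:

use candle::{Device, Result, Tensor};
use candle_nn::{BatchNorm, Module};

fn main() -> Result<()> {
    // A (batch, channels, height, width) input with 5 feature channels; dim 1 is
    // the feature dimension that the layer normalizes over.
    let x = Tensor::randn(0f32, 1f32, (2, 5, 3, 4), &Device::Cpu)?;
    // Inference-only batch norm without learnable gamma/beta.
    let bn = BatchNorm::new_no_bias(5, 1e-5)?;
    let y = bn.forward(&x)?;
    assert_eq!(y.dims(), &[2, 5, 3, 4]);
    Ok(())
}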
candle-nn/src/lib.rs
@@ -1,6 +1,7 @@
 use candle::{Result, Tensor};

 pub mod activation;
+pub mod batch_norm;
 pub mod conv;
 pub mod embedding;
 pub mod group_norm;
@@ -13,6 +14,7 @@ pub mod optim;
 pub mod var_builder;

 pub use activation::Activation;
+pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
 pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
 pub use embedding::{embedding, Embedding};
 pub use group_norm::{group_norm, GroupNorm};
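With those re-exports in place, the VarBuilder-based constructor can be reached from the crate root. A short sketch (again not part of the commit) of how it might be wired up; the helper name, feature count, and config values are hypothetical:

use candle::Result;
use candle_nn::{batch_norm, BatchNorm, BatchNormConfig, VarBuilder};

// Hypothetical helper: builds a 16-feature batch-norm layer from a VarBuilder.
// BatchNormConfig also implements From<f64>, so `batch_norm(16, 1e-5, vb)` with
// a bare epsilon works as well.
fn build_bn(vb: VarBuilder) -> Result<BatchNorm> {
    let cfg = BatchNormConfig {
        eps: 1e-5,
        remove_mean: true,
        // With `affine: false` no learnable parameters are created at all.
        affine: true,
    };
    batch_norm(16, cfg, vb)
}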
candle-nn/tests/batch_norm.rs (new file, 70 lines)
@@ -0,0 +1,70 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod test_utils;

use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::{BatchNorm, Module};

/* The test below has been generated using the following PyTorch code:
import torch
torch.manual_seed(19551105)
m = torch.nn.BatchNorm2d(5, affine=False)
input = torch.randn(2, 5, 3, 4)
output = m(input)
print(input.flatten())
print(output.flatten())
*/
#[test]
fn batch_norm() -> Result<()> {
    let bn = BatchNorm::new_no_bias(5, 1e-8)?;
    let input: [f32; 120] = [
        -0.7493, -1.0410, 1.6977, -0.6579, 1.7982, -0.0087, 0.2812, -0.1190, 0.2908, -0.5975,
        -0.0278, -0.2138, -1.3130, -1.6048, -2.2028, 0.9452, 0.4002, 0.0831, 1.0004, 0.1860,
        0.5004, 0.5539, 0.9991, -0.2540, -0.0703, -0.3752, -0.1096, -0.2374, 1.0258, -2.2208,
        -0.0257, 0.6073, -1.1627, -0.0964, -1.9718, 1.6577, 0.1931, -0.3692, -0.8011, 0.9059,
        0.4797, 0.6521, -0.0165, -0.6683, -0.4148, 2.0649, -0.8276, 1.7947, -0.2061, 0.5812,
        -1.3598, 1.6192, 1.0466, -0.4423, 0.4202, 0.1749, 0.6969, 0.2616, -0.0369, -1.4951,
        -0.0814, -0.1877, 0.0267, 0.6150, 0.2402, -1.1440, -2.0068, 0.6032, -2.6639, 0.8260,
        0.1085, -0.1693, 1.2805, 0.7654, -0.4930, 0.3770, 1.1309, 0.2303, 0.2949, -0.2634, -0.5225,
        0.4269, 0.6341, 1.5736, 0.9827, -1.2499, 0.3509, -1.6243, -0.8123, 0.7634, -0.3047, 0.0143,
        -0.4032, 0.0537, 0.7022, 0.8405, -1.2221, -1.6847, -0.0714, -0.1608, 0.5579, -1.5858,
        0.4617, -0.6480, 0.1332, 0.0419, -0.9784, 0.4173, 1.2313, -1.9046, -0.1656, 0.1259, 0.0763,
        1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
    ];
    let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
    let output = bn.forward(&input)?;
    assert_eq!(output.dims(), &[2, 5, 3, 4]);
    let output = output.flatten_all()?;
    assert_eq!(
        test_utils::to_vec1_round(&output, 4)?,
        &[
            -0.6391, -0.9414, 1.8965, -0.5444, 2.0007, 0.1283, 0.4287, 0.014, 0.4387, -0.4818,
            0.1085, -0.0842, -1.6809, -2.0057, -2.6714, 0.8328, 0.2262, -0.1268, 0.8943, -0.0123,
            0.3377, 0.3973, 0.8928, -0.5021, 0.0861, -0.2324, 0.0451, -0.0884, 1.2311, -2.1603,
            0.1327, 0.7939, -1.055, 0.0589, -1.9002, 1.8912, 0.2918, -0.3253, -0.7993, 1.0741,
            0.6063, 0.7955, 0.0617, -0.6536, -0.3754, 2.3461, -0.8284, 2.0495, -0.201, 0.6476,
            -1.4446, 1.7665, 1.1493, -0.4556, 0.4741, 0.2097, 0.7723, 0.3031, -0.0186, -1.5905,
            0.053, -0.0572, 0.165, 0.7746, 0.3862, -1.0481, -1.9422, 0.7624, -2.6231, 0.9933,
            0.2498, -0.0381, 1.2061, 0.6327, -0.7681, 0.2004, 1.0396, 0.037, 0.109, -0.5125,
            -0.8009, 0.2559, 0.4865, 1.5324, 1.1861, -1.1461, 0.5261, -1.5372, -0.689, 0.957,
            -0.1587, 0.1745, -0.2616, 0.2156, 0.8931, 1.0375, -1.2614, -1.7691, 0.0015, -0.0966,
            0.6921, -1.6605, 0.5866, -0.6313, 0.226, 0.1258, -0.9939, 0.5378, 1.3484, -2.0319,
            -0.1574, 0.1568, 0.1034, 1.5574, -0.9614, -0.0967, -0.313, -0.7047, -1.5264, 1.0134
        ]
    );
    let bn2 = BatchNorm::new(
        5,
        Tensor::new(&[0.5f32], &Device::Cpu)?.broadcast_as(5)?,
        Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
        1e-8,
    )?;
    let output2 = bn2.forward(&input)?;
    assert_eq!(output2.dims(), &[2, 5, 3, 4]);
    let output2 = output2.flatten_all()?;
    let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
    let sum_diff2 = diff2.sum_keepdim(0)?;
    assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]);
    Ok(())
}
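One behaviour of the forward implementation worth calling out: half-precision inputs are normalized internally in F32 and cast back on the way out. A small sketch of that, not part of the commit and assuming BF16 casts are supported on the CPU backend at this revision; the shapes and values are made up:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{BatchNorm, Module};

fn main() -> Result<()> {
    // Rank-3 input with 2 feature channels on dim 1, cast to BF16.
    let x = Tensor::new(&[0.5f32, -1.0, 2.0, 0.0, 1.5, -0.5, 0.25, -2.0], &Device::Cpu)?
        .reshape((1, 2, 4))?
        .to_dtype(DType::BF16)?;
    let bn = BatchNorm::new_no_bias(2, 1e-5)?;
    // The mean/variance math runs in F32 internally; the output dtype matches the input.
    let y = bn.forward(&x)?;
    assert_eq!(y.dtype(), DType::BF16);
    Ok(())
}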