[Breaking] Add training to batchnorm with exponential moving average (#1504)

* Add training to batchnorm with exponential moving average

* Add more checks to batch norm

* Resolve some review comments

* Add with_momentum varients of `new` methods

* Add check for range of momentum variable; update batch norm test

* Run cargo fmt

* Add back num_features parameter

* Format; tiny simplification
This commit is contained in:
nkoppel
2023-12-30 15:42:08 +00:00
committed by GitHub
parent 51e577a682
commit 4290b81244
2 changed files with 169 additions and 50 deletions

View File

@ -16,6 +16,8 @@ input = torch.randn(2, 5, 3, 4)
output = m(input)
print(input.flatten())
print(output.flatten())
print(m.running_mean)
print(m.running_var)
*/
#[test]
fn batch_norm() -> Result<()> {
@ -71,5 +73,14 @@ fn batch_norm() -> Result<()> {
let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
let sum_diff2 = diff2.sum_keepdim(0)?;
assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]);
assert_eq!(
test_utils::to_vec1_round(bn.running_mean(), 4)?,
&[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020]
);
assert_eq!(
test_utils::to_vec1_round(bn.running_var(), 4)?,
&[0.9972, 0.9842, 0.9956, 0.9866, 0.9898]
);
Ok(())
}