mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 20:38:06 +00:00
[Breaking] Add training to batchnorm with exponential moving average (#1504)
* Add training to batchnorm with exponential moving average * Add more checks to batch norm * Resolve some review comments * Add with_momentum varients of `new` methods * Add check for range of momentum variable; update batch norm test * Run cargo fmt * Add back num_features parameter * Format; tiny simplification
This commit is contained in:
@ -16,6 +16,8 @@ input = torch.randn(2, 5, 3, 4)
|
||||
output = m(input)
|
||||
print(input.flatten())
|
||||
print(output.flatten())
|
||||
print(m.running_mean)
|
||||
print(m.running_var)
|
||||
*/
|
||||
#[test]
|
||||
fn batch_norm() -> Result<()> {
|
||||
@ -71,5 +73,14 @@ fn batch_norm() -> Result<()> {
|
||||
let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
|
||||
let sum_diff2 = diff2.sum_keepdim(0)?;
|
||||
assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]);
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(bn.running_mean(), 4)?,
|
||||
&[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(bn.running_var(), 4)?,
|
||||
&[0.9972, 0.9842, 0.9956, 0.9866, 0.9898]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user