Do not implement Module for BatchNorm. (#1513)

This commit is contained in:
Laurent Mazare
2024-01-01 10:13:13 +01:00
committed by GitHub
parent 1fb2dd905c
commit b0fe5e4453
9 changed files with 31 additions and 33 deletions

View File

@ -39,7 +39,7 @@ fn batch_norm() -> Result<()> {
1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
];
let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
let output = bn.forward_learning(&input)?;
let output = bn.forward_train(&input)?;
assert_eq!(output.dims(), &[2, 5, 3, 4]);
let output = output.flatten_all()?;
assert_eq!(
@ -67,7 +67,7 @@ fn batch_norm() -> Result<()> {
Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
1e-8,
)?;
let output2 = bn2.forward_learning(&input)?;
let output2 = bn2.forward_train(&input)?;
assert_eq!(output2.dims(), &[2, 5, 3, 4]);
let output2 = output2.flatten_all()?;
let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;