mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Do not implement Module for BatchNorm. (#1513)
This commit is contained in:
@ -39,7 +39,7 @@ fn batch_norm() -> Result<()> {
|
||||
1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
|
||||
];
|
||||
let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
|
||||
let output = bn.forward_learning(&input)?;
|
||||
let output = bn.forward_train(&input)?;
|
||||
assert_eq!(output.dims(), &[2, 5, 3, 4]);
|
||||
let output = output.flatten_all()?;
|
||||
assert_eq!(
|
||||
@ -67,7 +67,7 @@ fn batch_norm() -> Result<()> {
|
||||
Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
|
||||
1e-8,
|
||||
)?;
|
||||
let output2 = bn2.forward_learning(&input)?;
|
||||
let output2 = bn2.forward_train(&input)?;
|
||||
assert_eq!(output2.dims(), &[2, 5, 3, 4]);
|
||||
let output2 = output2.flatten_all()?;
|
||||
let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
|
||||
|
Reference in New Issue
Block a user