Do not implement Module for BatchNorm. (#1513)

This commit is contained in:
Laurent Mazare
2024-01-01 10:13:13 +01:00
committed by GitHub
parent 1fb2dd905c
commit b0fe5e4453
9 changed files with 31 additions and 33 deletions

View File

@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module>
let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
Ok(candle_nn::func(move |xs| {
let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
}))
}
@ -64,7 +64,7 @@ fn convmixer(
.collect::<Result<Vec<_>>>()?;
let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
Ok(candle_nn::func(move |xs| {
let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
for block in blocks.iter() {
xs = xs.apply(block)?
}