mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Do not implement Module for BatchNorm. (#1513)
This commit is contained in:
@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module>
|
||||
let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
|
||||
let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
|
||||
Ok(candle_nn::func(move |xs| {
|
||||
let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
|
||||
(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
|
||||
let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
|
||||
(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
|
||||
}))
|
||||
}
|
||||
|
||||
@ -64,7 +64,7 @@ fn convmixer(
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
|
||||
Ok(candle_nn::func(move |xs| {
|
||||
let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
|
||||
let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
|
||||
for block in blocks.iter() {
|
||||
xs = xs.apply(block)?
|
||||
}
|
||||
|
Reference in New Issue
Block a user