mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Do not implement Module for BatchNorm. (#1513)
This commit is contained in:
@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul
|
||||
if stride != 1 || c_in != c_out {
|
||||
let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
|
||||
let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
|
||||
Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
|
||||
Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
|
||||
} else {
|
||||
Ok(Func::new(|xs| Ok(xs.clone())))
|
||||
}
|
||||
@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu
|
||||
Ok(Func::new(move |xs| {
|
||||
let ys = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.apply_t(&bn1, false)?
|
||||
.relu()?
|
||||
.apply(&conv2)?
|
||||
.apply(&bn2)?;
|
||||
.apply_t(&bn2, false)?;
|
||||
(xs.apply(&downsample)? + ys)?.relu()
|
||||
}))
|
||||
}
|
||||
@ -94,7 +94,7 @@ fn resnet(
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.apply_t(&bn1, false)?
|
||||
.relu()?
|
||||
.pad_with_same(D::Minus1, 1, 1)?
|
||||
.pad_with_same(D::Minus2, 1, 1)?
|
||||
@ -149,13 +149,13 @@ fn bottleneck_block(
|
||||
Ok(Func::new(move |xs| {
|
||||
let ys = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.apply_t(&bn1, false)?
|
||||
.relu()?
|
||||
.apply(&conv2)?
|
||||
.apply(&bn2)?
|
||||
.apply_t(&bn2, false)?
|
||||
.relu()?
|
||||
.apply(&conv3)?
|
||||
.apply(&bn3)?;
|
||||
.apply_t(&bn3, false)?;
|
||||
(xs.apply(&downsample)? + ys)?.relu()
|
||||
}))
|
||||
}
|
||||
@ -206,7 +206,7 @@ fn bottleneck_resnet(
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.apply_t(&bn1, false)?
|
||||
.relu()?
|
||||
.pad_with_same(D::Minus1, 1, 1)?
|
||||
.pad_with_same(D::Minus2, 1, 1)?
|
||||
|
Reference in New Issue
Block a user