Simplify usage of the pool functions. (#662)

* Simplify usage of the pool functions.

* Small tweak.

* Attempt at using apply to simplify the convnet definition.
This commit is contained in:
Laurent Mazare
2023-08-29 19:12:16 +01:00
committed by GitHub
parent b31d41e26a
commit 2d3fcad267
9 changed files with 86 additions and 42 deletions

View File

@ -83,13 +83,15 @@ impl Model for ConvNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b_sz, _img_dim) = xs.dims2()?;
let xs = xs.reshape((b_sz, 1, 28, 28))?;
let xs = self.conv1.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
let xs = self.conv2.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
let xs = xs.flatten_from(1)?;
let xs = self.fc1.forward(&xs)?;
let xs = xs.relu()?;
self.fc2.forward(&xs)
xs.reshape((b_sz, 1, 28, 28))?
.apply(&self.conv1)?
.max_pool2d(2)?
.apply(&self.conv2)?
.max_pool2d(2)?
.flatten_from(1)?
.apply(&self.fc1)?
.relu()?
.apply(&self.fc2)
}
}