mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Simplify usage of the pool functions. (#662)
* Simplify usage of the pool functions. * Small tweak. * Attempt at using apply to simplify the convnet definition.
This commit is contained in:
@ -83,13 +83,15 @@ impl Model for ConvNet {
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b_sz, _img_dim) = xs.dims2()?;
|
||||
let xs = xs.reshape((b_sz, 1, 28, 28))?;
|
||||
let xs = self.conv1.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
|
||||
let xs = self.conv2.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
|
||||
let xs = xs.flatten_from(1)?;
|
||||
let xs = self.fc1.forward(&xs)?;
|
||||
let xs = xs.relu()?;
|
||||
self.fc2.forward(&xs)
|
||||
xs.reshape((b_sz, 1, 28, 28))?
|
||||
.apply(&self.conv1)?
|
||||
.max_pool2d(2)?
|
||||
.apply(&self.conv2)?
|
||||
.max_pool2d(2)?
|
||||
.flatten_from(1)?
|
||||
.apply(&self.fc1)?
|
||||
.relu()?
|
||||
.apply(&self.fc2)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ impl Downsample2D {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
match &self.conv {
|
||||
None => xs.avg_pool2d((2, 2), (2, 2)),
|
||||
None => xs.avg_pool2d(2),
|
||||
Some(conv) => {
|
||||
if self.padding == 0 {
|
||||
let xs = xs
|
||||
|
@ -198,15 +198,15 @@ impl Module for Sppf {
|
||||
let xs2 = xs
|
||||
.pad_with_zeros(2, self.k / 2, self.k / 2)?
|
||||
.pad_with_zeros(3, self.k / 2, self.k / 2)?
|
||||
.max_pool2d((self.k, self.k), (1, 1))?;
|
||||
.max_pool2d_with_stride(self.k, 1)?;
|
||||
let xs3 = xs2
|
||||
.pad_with_zeros(2, self.k / 2, self.k / 2)?
|
||||
.pad_with_zeros(3, self.k / 2, self.k / 2)?
|
||||
.max_pool2d((self.k, self.k), (1, 1))?;
|
||||
.max_pool2d_with_stride(self.k, 1)?;
|
||||
let xs4 = xs3
|
||||
.pad_with_zeros(2, self.k / 2, self.k / 2)?
|
||||
.pad_with_zeros(3, self.k / 2, self.k / 2)?
|
||||
.max_pool2d((self.k, self.k), (1, 1))?;
|
||||
.max_pool2d_with_stride(self.k, 1)?;
|
||||
self.cv2.forward(&Tensor::cat(&[&xs, &xs2, &xs3, &xs4], 1)?)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user