Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)

* Skeleton for the avg-pool2d and upsample-nearest2d ops.

* Preliminary conv2d support.
This commit is contained in:
Laurent Mazare
2023-08-07 17:15:38 +02:00
committed by GitHub
parent f53a333ea9
commit 2345b8ce3f
7 changed files with 88 additions and 17 deletions

View File

@ -36,7 +36,7 @@ impl Downsample2D {
impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match &self.conv {
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
None => xs.avg_pool2d((2, 2), (2, 2)),
Some(conv) => {
if self.padding == 0 {
let xs = xs
@ -72,13 +72,10 @@ impl Upsample2D {
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
let xs = match size {
None => {
// The following does not work and it's tricky to pass no fixed
// dimensions so hack our way around this.
// xs.upsample_nearest2d(&[], Some(2.), Some(2.)
let (_bsize, _channels, _h, _w) = xs.dims4()?;
crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.))
let (_bsize, _channels, h, w) = xs.dims4()?;
xs.upsample_nearest2d(2 * h, 2 * w)?
}
Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None),
Some((h, w)) => xs.upsample_nearest2d(h, w)?,
};
self.conv.forward(&xs)
}