Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)

* Skeleton for the avg-pool2d and upsample-nearest2d ops.

* Preliminary conv2d support.
This commit is contained in:
Laurent Mazare
2023-08-07 17:15:38 +02:00
committed by GitHub
parent f53a333ea9
commit 2345b8ce3f
7 changed files with 88 additions and 17 deletions

View File

@ -55,6 +55,11 @@ impl Tensor {
kernel: rhs,
..
}
| Op::Conv2D {
arg: lhs,
kernel: rhs,
..
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
@ -81,6 +86,8 @@ impl Tensor {
}
}
Op::Reshape(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
@ -163,6 +170,11 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;

View File

@ -80,6 +80,21 @@ pub enum Op {
stride: usize,
},
#[allow(dead_code)]
Conv2D {
arg: Tensor,
kernel: Tensor,
padding: usize,
stride: usize,
},
AvgPool2D {
arg: Tensor,
kernel_size: (usize, usize),
stride: (usize, usize),
},
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
#[allow(dead_code)] // add is currently unused.

View File

@ -266,6 +266,24 @@ impl Storage {
}
}
pub(crate) fn avg_pool2d(
&self,
_layout: &Layout,
_kernel_size: (usize, usize),
_stride: (usize, usize),
) -> Result<Self> {
todo!()
}
pub(crate) fn upsample_nearest2d(
&self,
_layout: &Layout,
_h: usize,
_w: usize,
) -> Result<Self> {
todo!()
}
pub(crate) fn where_cond(
&self,
layout: &Layout,

View File

@ -817,6 +817,35 @@ impl Tensor {
Ok(from_storage(storage, out_dims, op, false))
}
pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> {
todo!()
}
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
.storage()
.upsample_nearest2d(self.layout(), target_h, target_w)?;
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
arg,
kernel_size,
stride,
});
let storage = self
.storage()
.avg_pool2d(self.layout(), kernel_size, stride)?;
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments