Add conv-transpose. (#635)

* Add conv-transpose.

* Return zeros for now.

* Naive CPU implementation.

* Add a conv-transpose test + fix the cpu implementation.

* Add a second test.
This commit is contained in:
Laurent Mazare
2023-08-28 10:10:12 +01:00
committed by GitHub
parent 26e1b40992
commit 3cca89cc70
9 changed files with 244 additions and 0 deletions

View File

@ -60,6 +60,11 @@ impl Tensor {
kernel: rhs,
..
}
| Op::ConvTranspose2D {
arg: lhs,
kernel: rhs,
..
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
@ -188,6 +193,9 @@ impl Tensor {
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {