Support dilation in conv-transpose2d. (#671)

This commit is contained in:
Laurent Mazare
2023-08-30 09:22:00 +01:00
committed by GitHub
parent 9b25113393
commit 393690387f
4 changed files with 40 additions and 17 deletions

View File

@ -155,15 +155,15 @@ __device__ void conv_transpose2d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (int k_x = 0; k_x < (int)w_k; ++k_x) {
// let out_x = inp_x * p.stride + k_x - p.padding;
int inp_x_stride = (int)(out_x + padding) - k_x;
// let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
if (inp_x_stride < 0 || inp_x_stride % stride) {
continue;
}
int inp_x = inp_x_stride / stride;
if (inp_x >= w_in) continue;
for (int k_y = 0; k_y < (int)h_k; ++k_y) {
int inp_y_stride = (int)(out_y + padding) - k_y;
int inp_y_stride = (int)(out_y + padding) - k_y * dilation;
if (inp_y_stride < 0 || inp_y_stride % stride) {
continue;
}