Dilated convolutions (#657)

* Add the dilation parameter.

* Restore the basic optimizer example.

* Dilation support in cudnn.

* Use the dilation parameter in the cpu backend.

* More dilation support.

* No support for dilation in transposed convolutions.

* Add dilation to a test.

* Remove a print.

* Helper function.
This commit is contained in:
Laurent Mazare
2023-08-29 16:12:11 +01:00
committed by GitHub
parent ee8bb1bde1
commit a044907ffc
20 changed files with 231 additions and 45 deletions

View File

@ -8,6 +8,7 @@ __device__ void conv1d(
const size_t l_out,
const size_t stride,
const size_t padding,
const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@ -36,7 +37,7 @@ __device__ void conv1d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t offset = 0; offset < k_size; ++offset) {
size_t src_l = stride * dst_l + offset;
size_t src_l = (stride * dst_l + offset) * dilation;
if (src_l < padding || src_l >= padding + l_in) {
continue;
}
@ -58,6 +59,7 @@ __device__ void conv2d(
const size_t h_out,
const size_t stride,
const size_t padding,
const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@ -90,13 +92,13 @@ __device__ void conv2d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
size_t src_w = stride * dst_w + w_offset;
size_t src_w = (stride * dst_w + w_offset) * dilation;
if (src_w < padding || src_w >= w_in + padding) {
continue;
}
src_w -= padding;
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
size_t src_h = stride * dst_h + h_offset;
size_t src_h = (stride * dst_h + h_offset) * dilation;
if (src_h < padding || src_h >= h_in + padding) {
continue;
}
@ -120,6 +122,7 @@ __device__ void conv_transpose2d(
const size_t stride,
const size_t padding,
const size_t out_padding,
const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@ -335,12 +338,13 @@ extern "C" __global__ void FN_NAME( \
const size_t num_dims, \
const size_t stride, \
const size_t padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, info, src, kernel, dst); \
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, dilation, info, src, kernel, dst); \
} \
#define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@ -350,12 +354,13 @@ extern "C" __global__ void FN_NAME( \
const size_t h_out, \
const size_t stride, \
const size_t padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@ -366,12 +371,13 @@ extern "C" __global__ void FN_NAME( \
const size_t stride, \
const size_t padding, \
const size_t out_padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
} \
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \