Fix the dilated convolutions. (#659)

This commit is contained in:
Laurent Mazare
2023-08-29 16:37:42 +01:00
committed by GitHub
parent a044907ffc
commit 71221559d3
3 changed files with 17 additions and 17 deletions

View File

@ -92,13 +92,13 @@ __device__ void conv2d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
size_t src_w = (stride * dst_w + w_offset) * dilation;
size_t src_w = stride * dst_w + w_offset * dilation;
if (src_w < padding || src_w >= w_in + padding) {
continue;
}
src_w -= padding;
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
size_t src_h = (stride * dst_h + h_offset) * dilation;
size_t src_h = stride * dst_h + h_offset * dilation;
if (src_h < padding || src_h >= h_in + padding) {
continue;
}