From 71221559d306b1d504820a9561533cb521ffb39a Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 29 Aug 2023 16:37:42 +0100 Subject: [PATCH] Fix the dilated convolutions. (#659) --- candle-core/src/cpu_backend.rs | 6 +++--- candle-core/tests/conv_tests.rs | 24 ++++++++++++------------ candle-kernels/src/conv.cu | 4 ++-- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 60fac0c9..d4615b0a 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> { let dst_idx = dst_idx + b_idx * p.c_out * l_out; for dst_l in 0..l_out { let dst_idx = dst_idx + dst_l; - let src_l = (p.stride * dst_l + offset) * p.dilation; + let src_l = p.stride * dst_l + offset * p.dilation; if src_l < p.padding || src_l >= p.padding + p.l_in { continue; } @@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> { let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w; for dst_h in 0..out_h { let dst_idx = dst_idx + dst_h * out_w; - let src_h = (p.stride * dst_h + offset_h) * p.dilation; + let src_h = p.stride * dst_h + offset_h * p.dilation; if src_h < p.padding || src_h >= p.i_h + p.padding { continue; } let src_h = src_h - p.padding; for dst_w in 0..out_w { let dst_idx = dst_idx + dst_w; - let src_w = (p.stride * dst_w + offset_w) * p.dilation; + let src_w = p.stride * dst_w + offset_w * p.dilation; if src_w < p.padding || src_w >= p.i_w + p.padding { continue; } diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 05015995..8196a27e 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -423,24 +423,24 @@ fn conv2d_grad(dev: &Device) -> Result<()> { test_utils::to_vec3_round(&grad_w.i(0)?, 2)?, [ [ - [28.34, -45.75, 7.32], - [0.72, -35.28, 19.23], - [-28.29, 20.89, -5.18] + [28.34, -7.91, -45.75], + [21.03, 3.86, 29.86], + [0.72, -36.58, -35.28] ], [ - [-16.04, -16.38, 32.12], - [57.5, 25.81, 11.96], - [-18.66, 8.48, -9.92] + [-16.04, 11.53, -16.38], + [29.62, -16.32, -48.35], + [57.5, 28.29, 25.81] ], [ - [2.93, 1.57, -23.76], - [12.74, -26.2, -17.88], - [-14.98, -9.35, 12.2] + [2.93, -19.6, 1.57], + [27.15, 53.88, -24.64], + [12.74, -22.6, -26.2] ], [ - [-0.18, -6.82, 20.79], - [-2.54, 27.11, -10.11], - [-0.41, -3.18, -0.07] + [-0.18, -14.86, -6.82], + [-19.55, -2.72, 45.9], + [-2.54, 36.97, 27.11] ] ] ); diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index c67a4300..91f4c7b2 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -92,13 +92,13 @@ __device__ void conv2d( const size_t src_idx0 = b_idx * src_s[0]; A d = 0; for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { - size_t src_w = (stride * dst_w + w_offset) * dilation; + size_t src_w = stride * dst_w + w_offset * dilation; if (src_w < padding || src_w >= w_in + padding) { continue; } src_w -= padding; for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { - size_t src_h = (stride * dst_h + h_offset) * dilation; + size_t src_h = stride * dst_h + h_offset * dilation; if (src_h < padding || src_h >= h_in + padding) { continue; }