From ca26198b95aefe31ae5fd7845e16d18d646c8b32 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 28 Aug 2023 16:45:12 +0100 Subject: [PATCH] Fix the cpu kernel for conv-transpose. (#643) --- candle-core/src/cpu_backend.rs | 24 +++++++-------- candle-core/tests/conv_tests.rs | 54 +++++++++++++++++++++++++++------ 2 files changed, 55 insertions(+), 23 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index bc4470d6..07fc78fc 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1199,25 +1199,23 @@ impl<'a> Map2 for ConvTranspose2D<'a> { let dst_s2 = out_w; let dst_s3 = 1; for b_idx in 0..p.b_size { - for out_y in 0..out_h as i32 { - for out_x in 0..out_w as i32 { - let inp_x = out_x * p.stride as i32 - p.padding as i32; - let inp_y = out_y * p.stride as i32 - p.padding as i32; + for inp_y in 0..p.i_h { + for inp_x in 0..p.i_w { + let out_x = (inp_x * p.stride) as i32 - p.padding as i32; + let out_y = (inp_y * p.stride) as i32 - p.padding as i32; for k_y in 0..p.k_h as i32 { for k_x in 0..p.k_w as i32 { let k_index = k_y as usize * k_s2 + k_x as usize * k_s3; - let inp_y = inp_y + k_y; - let inp_x = inp_x + k_x; - if inp_x < 0 || inp_y < 0 { + let out_y = out_y + k_y; + let out_x = out_x + k_x; + if out_x < 0 || out_y < 0 { continue; } - let inp_x = inp_x as usize; - let inp_y = inp_y as usize; - if inp_x < p.i_w && inp_y < p.i_h { + let out_x = out_x as usize; + let out_y = out_y as usize; + if out_x < out_w && out_y < out_h { let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3; - let dst_index = b_idx * dst_s0 - + out_y as usize * dst_s2 - + out_x as usize * dst_s3; + let dst_index = b_idx * dst_s0 + out_y * dst_s2 + out_x * dst_s3; for c_out in 0..p.c_out { for c_in in 0..p.c_in { let k_index = k_index + c_out * k_s1 + c_in * k_s0; diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 2bd4e9df..4fe76378 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use candle_core::{test_device, test_utils, Device, Tensor}; +use candle_core::{test_device, test_utils, Device, IndexOp, Tensor}; /* This test is based on the following script. import torch @@ -76,6 +76,11 @@ print(t.flatten()) print(w.flatten()) res = torch.nn.functional.conv2d(t, w) print(res.flatten()) + +w_t = w.transpose(0, 1) +res = torch.nn.functional.conv_transpose2d(t, w_t) +print(res.shape) +print(res) */ fn conv2d(dev: &Device) -> Result<()> { let t = Tensor::new( @@ -117,6 +122,33 @@ fn conv2d(dev: &Device) -> Result<()> { 10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075 ] ); + if dev.is_cpu() { + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?; + assert_eq!(res.dims(), [1, 2, 7, 7]); + assert_eq!( + test_utils::to_vec3_round(&res.i(0)?, 4)?, + [ + [ + [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277], + [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375], + [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889], + [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632], + [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985], + [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114], + [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579] + ], + [ + [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211], + [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131], + [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621], + [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142], + [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059], + [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516], + [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171] + ] + ] + ); + } Ok(()) } @@ -287,16 +319,18 @@ fn conv2d_grad() -> Result<()> { assert_eq!(grad_w.dims(), [2, 4, 3, 3]); assert_eq!( test_utils::to_vec1_round(&grad_t.flatten_all()?, 4)?, - // THIS IS WRONG AT THE MOMENT [ - 1.7442, -10.1747, -9.9426, 0.0, 0.0, -1.7046, -21.2248, 30.8435, 0.0, 0.0, -18.713, - -1.0547, -7.8746, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 169.3047, - 46.0812, 40.6937, 0.0, 0.0, -85.8156, 4.537, 53.2871, 0.0, 0.0, -59.632, -35.9725, - -7.1689, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 48.823, 8.9794, - 42.3011, 0.0, 0.0, -58.9268, 32.907, -50.6863, 0.0, 0.0, -0.9706, -3.9175, -4.2594, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 72.8229, 25.8492, 8.9871, - 0.0, 0.0, -136.2584, 40.1739, 88.9583, 0.0, 0.0, -53.465, -40.7102, -24.9406, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 + 9.2868, -2.8352, -5.7117, 3.3817, -7.7094, -19.1549, 7.016, 29.1037, 9.3411, 34.7339, + -22.8726, 24.3502, -39.88, -14.007, 21.076, 9.9419, 13.6333, -34.6796, 11.2073, + -6.2617, 7.7209, -6.3224, -16.6373, -1.0837, -20.2215, 21.7302, -0.3744, -4.0573, + 5.8163, -3.6529, -30.7319, 14.5468, 87.699, 31.6035, 4.5304, -89.785, -75.3709, + -57.4327, -7.5602, 92.9585, 18.791, -4.6311, -159.7521, -42.4656, -47.2644, 52.8768, + 37.3172, 48.9978, 12.8192, 2.014, -8.9826, 20.1759, 16.621, 12.0599, 15.3849, 19.9979, + 2.5725, -15.2197, 72.6244, -10.7496, 2.2541, -31.2003, 3.753, -0.2049, 9.7574, -0.6824, + 5.2107, -40.4361, -22.5891, -61.6085, 17.2837, 20.4149, 37.5454, 5.2262, 6.8126, + 23.5361, 23.6173, -9.9866, -9.1324, 4.8664, -35.0617, -26.1023, 63.4757, 25.8144, + -39.2069, -70.6834, -46.9565, 2.3252, 41.8093, 82.4205, -28.626, -11.7812, -35.3284, + -10.2771, -28.5694, -9.1258, 7.213, -9.0459, -9.6222, -11.2544 ] ); assert_eq!(