From fc67d878bb4a25cbeba361d0a31290f14beb9344 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 19 Feb 2024 09:04:49 +0100 Subject: [PATCH] Bugfix for conv-transpose1d (#1734) * Add a currently broken test. * Bugfix + fix test. --- candle-core/src/cpu_backend.rs | 1 + candle-core/tests/conv_tests.rs | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index f912c1b2..05e8c979 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1263,6 +1263,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> { fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { let p = self.0; let inp = &inp[inp_l.start_offset()..]; + let k = &k[k_l.start_offset()..]; let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?; let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; let l_out = p.l_out(); diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 211a1fe0..b967515d 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -18,6 +18,9 @@ w_t = w.transpose(0, 1) res = torch.nn.functional.conv_transpose1d(t, w_t) print(res.shape) print(res) +res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2) +print(res.shape) +print(res) */ fn conv1d(dev: &Device) -> Result<()> { let t = Tensor::new( @@ -59,6 +62,17 @@ fn conv1d(dev: &Device) -> Result<()> { 4.7076, -5.9745, -0.8276, 1.621 ], ); + let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?; + assert_eq!(res.dims(), [1, 4, 7]); + assert_eq!( + test_utils::to_vec2_round(&res.squeeze(0)?, 4)?, + [ + [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819], + [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721], + [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113], + [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488] + ] + ); Ok(()) }