From 01ea57da8c936e2b73b40848cf2eab9cadab94c9 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 11 Aug 2023 15:59:54 +0200 Subject: [PATCH] Fix the conv tests. (#409) --- candle-core/tests/conv_tests.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 7ec83592..f955b4a5 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -12,6 +12,8 @@ print(t.flatten()) print(w.flatten()) res = torch.nn.functional.conv1d(t, w) print(res.flatten()) +res = torch.nn.functional.conv1d(t, w, padding=1) +print(res.flatten()) */ #[test] fn conv1d() -> Result<()> { @@ -41,14 +43,10 @@ fn conv1d() -> Result<()> { ); let res = t.conv1d(&w, /*padding*/ 1, 1)?; assert_eq!(res.dims(), [1, 2, 5]); - /* Note that the default for padding is different from PyTorch at the moment: instead of - padding with zeros, the edge value from the input tensor is used, i.e. this is similiar to: - t = torch.nn.functional.pad(t, (1, 1), mode='replicate') - res = torch.nn.functional.conv1d(t, w, padding=0) - */ + // Same as pytorch default padding: use zeros. assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, - [2.5209, 2.6357, -1.3336, 4.1393, 0.4951, 3.6855, -1.1784, 3.5675, 0.5069, 4.9562] + [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] ); Ok(()) } @@ -68,7 +66,7 @@ fn conv1d_small() -> Result<()> { assert_eq!(res.dims(), [1, 1, 4]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, - [0.4056, 0.4056, -0.8689, -0.0773], + [0.0, 0.4056, -0.8689, -0.0773], ); Ok(()) }