From cd225bd3b1279392819da24cc058fcb74bc6428f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 9 Aug 2023 17:12:23 +0200 Subject: [PATCH] More testing for avg-pool2d. (#366) * More testing for avg-pool2d. * Another fix. * Add a max-pool test with non-divisible kernel sizes. --- candle-core/src/cpu_backend.rs | 4 ++-- candle-core/tests/pool_tests.rs | 38 ++++++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 10c6cc4a..155df1e9 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -660,8 +660,8 @@ impl Map1 for AvgPool2D { let mut sum = T::zero(); for m in 0..k_h { for n in 0..k_w { - let m = k_h * h_idx + m; - let n = k_w * w_idx + n; + let m = s_h * h_idx + m; + let n = s_w * w_idx + n; sum += src[src_index + m * stride_h + n * stride_w] } } diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index 574182ca..7335e378 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -4,14 +4,46 @@ use candle_core::{Device, Tensor}; // https://github.com/huggingface/candle/issues/364 #[test] fn avg_pool2d() -> anyhow::Result<()> { - let device = Device::Cpu; - let data: Vec = vec![ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., ]; - let t = Tensor::from_vec(data, (1, 1, 4, 4), &device)?; + let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?; let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?; assert_eq!(pool.to_vec2::()?, [[0.5f32, 1.], [1., 1.]]); Ok(()) } + +/* This test corresponds to the following PyTorch script. +import torch +torch.manual_seed(4242) + +t = torch.randn((1, 2, 4, 4)) +print(t.flatten()) +res = torch.nn.functional.avg_pool2d(t, 2) +print(res) +*/ +#[test] +fn avg_pool2d_pytorch() -> anyhow::Result<()> { + let t = Tensor::new( + &[ + 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, + 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395, + 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836, + 0.2477, 1.3127, + ], + &Device::Cpu, + )? + .reshape((1, 2, 4, 4))?; + let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?; + assert_eq!( + test_utils::to_vec3_round(pool, 4)?, + [ + [[-1.1926, -0.0395], [0.2688, 0.1871]], + [[0.1835, -0.1606], [0.6249, 0.3217]] + ] + ); + let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?; + assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]); + Ok(()) +}