diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 0ec19559..10c6cc4a 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -660,6 +660,8 @@ impl Map1 for AvgPool2D { let mut sum = T::zero(); for m in 0..k_h { for n in 0..k_w { + let m = k_h * h_idx + m; + let n = k_w * w_idx + n; sum += src[src_index + m * stride_h + n * stride_w] } } diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs new file mode 100644 index 00000000..574182ca --- /dev/null +++ b/candle-core/tests/pool_tests.rs @@ -0,0 +1,17 @@ +mod test_utils; +use candle_core::{Device, Tensor}; + +// https://github.com/huggingface/candle/issues/364 +#[test] +fn avg_pool2d() -> anyhow::Result<()> { + let device = Device::Cpu; + + let data: Vec = vec![ + 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + ]; + let t = Tensor::from_vec(data, (1, 1, 4, 4), &device)?; + + let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?; + assert_eq!(pool.to_vec2::()?, [[0.5f32, 1.], [1., 1.]]); + Ok(()) +}