Bugfix for avg-pool + add some test. (#365)

This commit is contained in:
Laurent Mazare
2023-08-09 16:44:16 +02:00
committed by GitHub
parent 3a62aee91f
commit b80348d22f
2 changed files with 19 additions and 0 deletions

View File

@ -660,6 +660,8 @@ impl Map1 for AvgPool2D {
let mut sum = T::zero();
for m in 0..k_h {
for n in 0..k_w {
let m = k_h * h_idx + m;
let n = k_w * w_idx + n;
sum += src[src_index + m * stride_h + n * stride_w]
}
}

View File

@ -0,0 +1,17 @@
mod test_utils;
use candle_core::{Device, Tensor};
// https://github.com/huggingface/candle/issues/364
#[test]
fn avg_pool2d() -> anyhow::Result<()> {
let device = Device::Cpu;
let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), &device)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
Ok(())
}