mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Bugfix for avg-pool + add some test. (#365)
This commit is contained in:
@ -660,6 +660,8 @@ impl Map1 for AvgPool2D {
|
|||||||
let mut sum = T::zero();
|
let mut sum = T::zero();
|
||||||
for m in 0..k_h {
|
for m in 0..k_h {
|
||||||
for n in 0..k_w {
|
for n in 0..k_w {
|
||||||
|
let m = k_h * h_idx + m;
|
||||||
|
let n = k_w * w_idx + n;
|
||||||
sum += src[src_index + m * stride_h + n * stride_w]
|
sum += src[src_index + m * stride_h + n * stride_w]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
17
candle-core/tests/pool_tests.rs
Normal file
17
candle-core/tests/pool_tests.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
mod test_utils;
|
||||||
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
|
// https://github.com/huggingface/candle/issues/364
|
||||||
|
#[test]
|
||||||
|
fn avg_pool2d() -> anyhow::Result<()> {
|
||||||
|
let device = Device::Cpu;
|
||||||
|
|
||||||
|
let data: Vec<f32> = vec![
|
||||||
|
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
|
];
|
||||||
|
let t = Tensor::from_vec(data, (1, 1, 4, 4), &device)?;
|
||||||
|
|
||||||
|
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||||
|
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
|
||||||
|
Ok(())
|
||||||
|
}
|
Reference in New Issue
Block a user