mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
More testing for avg-pool2d. (#366)
* More testing for avg-pool2d. * Another fix. * Add a max-pool test with non-divisible kernel sizes.
This commit is contained in:
@ -660,8 +660,8 @@ impl Map1 for AvgPool2D {
|
|||||||
let mut sum = T::zero();
|
let mut sum = T::zero();
|
||||||
for m in 0..k_h {
|
for m in 0..k_h {
|
||||||
for n in 0..k_w {
|
for n in 0..k_w {
|
||||||
let m = k_h * h_idx + m;
|
let m = s_h * h_idx + m;
|
||||||
let n = k_w * w_idx + n;
|
let n = s_w * w_idx + n;
|
||||||
sum += src[src_index + m * stride_h + n * stride_w]
|
sum += src[src_index + m * stride_h + n * stride_w]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,14 +4,46 @@ use candle_core::{Device, Tensor};
|
|||||||
// https://github.com/huggingface/candle/issues/364
|
// https://github.com/huggingface/candle/issues/364
|
||||||
#[test]
|
#[test]
|
||||||
fn avg_pool2d() -> anyhow::Result<()> {
|
fn avg_pool2d() -> anyhow::Result<()> {
|
||||||
let device = Device::Cpu;
|
|
||||||
|
|
||||||
let data: Vec<f32> = vec![
|
let data: Vec<f32> = vec![
|
||||||
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
];
|
];
|
||||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), &device)?;
|
let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
|
||||||
|
|
||||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||||
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
|
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* This test corresponds to the following PyTorch script.
|
||||||
|
import torch
|
||||||
|
torch.manual_seed(4242)
|
||||||
|
|
||||||
|
t = torch.randn((1, 2, 4, 4))
|
||||||
|
print(t.flatten())
|
||||||
|
res = torch.nn.functional.avg_pool2d(t, 2)
|
||||||
|
print(res)
|
||||||
|
*/
|
||||||
|
#[test]
|
||||||
|
fn avg_pool2d_pytorch() -> anyhow::Result<()> {
|
||||||
|
let t = Tensor::new(
|
||||||
|
&[
|
||||||
|
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||||
|
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
|
||||||
|
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
|
||||||
|
0.2477, 1.3127,
|
||||||
|
],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?
|
||||||
|
.reshape((1, 2, 4, 4))?;
|
||||||
|
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec3_round(pool, 4)?,
|
||||||
|
[
|
||||||
|
[[-1.1926, -0.0395], [0.2688, 0.1871]],
|
||||||
|
[[0.1835, -0.1606], [0.6249, 0.3217]]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
|
||||||
|
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user