mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
More non square testing (#582)
* Add more non square testing. * More testing.
This commit is contained in:
@ -9,6 +9,13 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
|
||||
|
||||
let data: Vec<f32> = vec![
|
||||
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -20,6 +27,10 @@ fn max_pool2d(dev: &Device) -> Result<()> {
|
||||
|
||||
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
|
||||
|
||||
let t = t.reshape((1, 1, 2, 8))?;
|
||||
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -53,6 +64,16 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||
);
|
||||
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
|
||||
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
|
||||
|
||||
let t = t.reshape((1, 1, 4, 8))?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&pool, 4)?,
|
||||
[
|
||||
[0.7745, 0.0276, -1.6983, 0.12],
|
||||
[0.3542, 0.1625, 0.4542, -0.0014]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user