add max_pool2d (#371)

Co-authored-by: 赵理山 <ls@zhaolishandeMacBook-Air.local>
This commit is contained in:
LeeeSe
2023-08-10 01:05:26 +08:00
committed by GitHub
parent 1892bd139c
commit a5c5a893aa
9 changed files with 115 additions and 0 deletions

View File

@ -14,6 +14,18 @@ fn avg_pool2d() -> anyhow::Result<()> {
Ok(())
}
#[test]
fn max_pool2d() -> anyhow::Result<()> {
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
Ok(())
}
/* This test corresponds to the following PyTorch script.
import torch
torch.manual_seed(4242)