mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
add max_pool2d (#371)
Co-authored-by: 赵理山 <ls@zhaolishandeMacBook-Air.local>
This commit is contained in:
@ -14,6 +14,18 @@ fn avg_pool2d() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_pool2d() -> anyhow::Result<()> {
|
||||
let data: Vec<f32> = vec![
|
||||
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
|
||||
|
||||
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/* This test corresponds to the following PyTorch script.
|
||||
import torch
|
||||
torch.manual_seed(4242)
|
||||
|
Reference in New Issue
Block a user