mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
More non square testing (#582)
* Add more non square testing. * More testing.
This commit is contained in:
@ -9,6 +9,13 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
|
|||||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
|
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
|
||||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||||
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
|
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
|
||||||
|
|
||||||
|
let data: Vec<f32> = vec![
|
||||||
|
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
|
||||||
|
];
|
||||||
|
let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
|
||||||
|
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||||
|
assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -20,6 +27,10 @@ fn max_pool2d(dev: &Device) -> Result<()> {
|
|||||||
|
|
||||||
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||||
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
|
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
|
||||||
|
|
||||||
|
let t = t.reshape((1, 1, 2, 8))?;
|
||||||
|
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||||
|
assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,6 +64,16 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
|||||||
);
|
);
|
||||||
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
|
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
|
||||||
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
|
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
|
||||||
|
|
||||||
|
let t = t.reshape((1, 1, 4, 8))?;
|
||||||
|
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec2_round(&pool, 4)?,
|
||||||
|
[
|
||||||
|
[0.7745, 0.0276, -1.6983, 0.12],
|
||||||
|
[0.3542, 0.1625, 0.4542, -0.0014]
|
||||||
|
]
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user