Move the test-utils bits to a shared place. (#619)

This commit is contained in:
Laurent Mazare
2023-08-27 09:42:22 +01:00
committed by GitHub
parent a8b39dd7b7
commit 5320aa6b7d
17 changed files with 34 additions and 88 deletions

View File

@ -1,5 +1,4 @@
mod test_utils;
use candle_core::{Device, IndexOp, Result, Tensor};
use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
@ -56,14 +55,17 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
.reshape((1, 2, 4, 4))?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(pool, 4)?,
test_utils::to_vec3_round(&pool, 4)?,
[
[[-1.1926, -0.0395], [0.2688, 0.1871]],
[[0.1835, -0.1606], [0.6249, 0.3217]]
]
);
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
assert_eq!(
test_utils::to_vec3_round(&pool, 4)?,
[[[0.085]], [[0.0078]]]
);
let t = t.reshape((1, 1, 4, 8))?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;