mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Move the test-utils bits to a shared place. (#619)
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
mod test_utils;
|
||||
use candle_core::{Device, IndexOp, Result, Tensor};
|
||||
use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
|
||||
|
||||
// https://github.com/huggingface/candle/issues/364
|
||||
fn avg_pool2d(dev: &Device) -> Result<()> {
|
||||
@ -56,14 +55,17 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||
.reshape((1, 2, 4, 4))?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(pool, 4)?,
|
||||
test_utils::to_vec3_round(&pool, 4)?,
|
||||
[
|
||||
[[-1.1926, -0.0395], [0.2688, 0.1871]],
|
||||
[[0.1835, -0.1606], [0.6249, 0.3217]]
|
||||
]
|
||||
);
|
||||
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
|
||||
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&pool, 4)?,
|
||||
[[[0.085]], [[0.0078]]]
|
||||
);
|
||||
|
||||
let t = t.reshape((1, 1, 4, 8))?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
|
Reference in New Issue
Block a user