Upsample test + bugfix. (#399)

This commit is contained in:
Laurent Mazare
2023-08-10 21:02:35 +02:00
committed by GitHub
parent b6cf26e48e
commit a325c1aa50
2 changed files with 23 additions and 3 deletions

View File

@ -1,5 +1,5 @@
mod test_utils;
use candle_core::{Device, Tensor};
use candle_core::{Device, IndexOp, Tensor};
// https://github.com/huggingface/candle/issues/364
#[test]
@ -59,3 +59,23 @@ fn avg_pool2d_pytorch() -> anyhow::Result<()> {
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
Ok(())
}
#[test]
fn upsample_nearest2d() -> anyhow::Result<()> {
let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?;
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
assert_eq!(
t.i(0)?.i(0)?.to_vec2::<f32>()?,
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
);
assert_eq!(
upsampled.to_vec2::<f32>()?,
[
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0]
]
);
Ok(())
}