From a325c1aa50fe31679dc24b6e6a796e9470698059 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 10 Aug 2023 21:02:35 +0200 Subject: [PATCH] Upsample test + bugfix. (#399) --- candle-core/src/cpu_backend.rs | 4 ++-- candle-core/tests/pool_tests.rs | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index fa24c434..6d129680 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -729,10 +729,10 @@ impl Map1 for UpsampleNearest2D { let scale_h = src_h as f64 / dst_h as f64; let scale_w = src_w as f64 / dst_w as f64; let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w]; - let src_h_idxs = (0..src_h) + let src_h_idxs = (0..dst_h) .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize)) .collect::>(); - let src_w_idxs = (0..src_w) + let src_w_idxs = (0..dst_w) .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize)) .collect::>(); for b_idx in 0..b_sz { diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index c8ddef97..73bf7434 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -1,5 +1,5 @@ mod test_utils; -use candle_core::{Device, Tensor}; +use candle_core::{Device, IndexOp, Tensor}; // https://github.com/huggingface/candle/issues/364 #[test] @@ -59,3 +59,23 @@ fn avg_pool2d_pytorch() -> anyhow::Result<()> { assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]); Ok(()) } + +#[test] +fn upsample_nearest2d() -> anyhow::Result<()> { + let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?; + let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?; + assert_eq!( + t.i(0)?.i(0)?.to_vec2::()?, + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + ); + assert_eq!( + upsampled.to_vec2::()?, + [ + [0.0, 0.0, 1.0, 1.0, 2.0, 2.0], + [0.0, 0.0, 1.0, 1.0, 2.0, 2.0], + [3.0, 3.0, 4.0, 4.0, 5.0, 5.0], + [3.0, 3.0, 4.0, 4.0, 5.0, 5.0] + ] + ); + Ok(()) +}