mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Upsample test + bugfix. (#399)
This commit is contained in:
@ -729,10 +729,10 @@ impl Map1 for UpsampleNearest2D {
|
|||||||
let scale_h = src_h as f64 / dst_h as f64;
|
let scale_h = src_h as f64 / dst_h as f64;
|
||||||
let scale_w = src_w as f64 / dst_w as f64;
|
let scale_w = src_w as f64 / dst_w as f64;
|
||||||
let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
|
let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
|
||||||
let src_h_idxs = (0..src_h)
|
let src_h_idxs = (0..dst_h)
|
||||||
.map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
|
.map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let src_w_idxs = (0..src_w)
|
let src_w_idxs = (0..dst_w)
|
||||||
.map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
|
.map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
for b_idx in 0..b_sz {
|
for b_idx in 0..b_sz {
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
mod test_utils;
|
mod test_utils;
|
||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, IndexOp, Tensor};
|
||||||
|
|
||||||
// https://github.com/huggingface/candle/issues/364
|
// https://github.com/huggingface/candle/issues/364
|
||||||
#[test]
|
#[test]
|
||||||
@ -59,3 +59,23 @@ fn avg_pool2d_pytorch() -> anyhow::Result<()> {
|
|||||||
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
|
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn upsample_nearest2d() -> anyhow::Result<()> {
|
||||||
|
let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?;
|
||||||
|
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
|
||||||
|
assert_eq!(
|
||||||
|
t.i(0)?.i(0)?.to_vec2::<f32>()?,
|
||||||
|
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
upsampled.to_vec2::<f32>()?,
|
||||||
|
[
|
||||||
|
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
|
||||||
|
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
|
||||||
|
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
|
||||||
|
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user