From 7473c4cecaeb4cde39452c0b3442a3f6105c53d7 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 12 Oct 2023 15:25:05 +0200 Subject: [PATCH] Fix the npy read function and add some testing. (#1080) --- candle-core/src/npy.rs | 2 -- candle-core/tests/npy.py | 9 +++++++++ candle-core/tests/serialization_tests.rs | 24 +++++++++++++++++++++++ candle-core/tests/test.npy | Bin 0 -> 208 bytes candle-core/tests/test.npz | Bin 0 -> 668 bytes 5 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 candle-core/tests/npy.py create mode 100644 candle-core/tests/serialization_tests.rs create mode 100644 candle-core/tests/test.npy create mode 100644 candle-core/tests/test.npz diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index f3a75965..1dfcc9b4 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -250,8 +250,6 @@ impl Tensor { if header.fortran_order { return Err(Error::Npy("fortran order not supported".to_string())); } - let mut data: Vec = vec![]; - reader.read_to_end(&mut data)?; Self::from_reader(header.shape(), header.descr, &mut reader) } diff --git a/candle-core/tests/npy.py b/candle-core/tests/npy.py new file mode 100644 index 00000000..0fd2778a --- /dev/null +++ b/candle-core/tests/npy.py @@ -0,0 +1,9 @@ +import numpy as np +x = np.arange(10) + +# Write a npy file. +np.save("test.npy", x) + +# Write multiple values to a npz file. +values = { "x": x, "x_plus_one": x + 1 } +np.savez("test.npz", **values) diff --git a/candle-core/tests/serialization_tests.rs b/candle-core/tests/serialization_tests.rs new file mode 100644 index 00000000..415306f4 --- /dev/null +++ b/candle-core/tests/serialization_tests.rs @@ -0,0 +1,24 @@ +use candle_core::{DType, Result, Tensor}; + +#[test] +fn npy() -> Result<()> { + let npy = Tensor::read_npy("tests/test.npy")?; + assert_eq!( + npy.to_dtype(DType::U8)?.to_vec1::()?, + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ); + Ok(()) +} + +#[test] +fn npz() -> Result<()> { + let npz = Tensor::read_npz("tests/test.npz")?; + assert_eq!(npz.len(), 2); + assert_eq!(npz[0].0, "x"); + assert_eq!(npz[1].0, "x_plus_one"); + assert_eq!( + npz[1].1.to_dtype(DType::U8)?.to_vec1::()?, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + ); + Ok(()) +} diff --git a/candle-core/tests/test.npy b/candle-core/tests/test.npy new file mode 100644 index 0000000000000000000000000000000000000000..a3ff5af9540510ac6518b5a2478bc1a40f992371 GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlWC!@qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= wXCxM+0{I$-20EHL3bhL411<(AV1&|4P?{M^vp{K9D9r|?*`YKCl;(ue0P!XpoB#j- literal 0 HcmV?d00001 diff --git a/candle-core/tests/test.npz b/candle-core/tests/test.npz new file mode 100644 index 0000000000000000000000000000000000000000..b6683caab0e0e9e66cb5ca649aab4980a1e650df GIT binary patch literal 668 zcmWIWW@Zs#fB;1X8xBpr3qTGCvoeSRzLK#d@ni)#7KxtMe%?72}p)?1S=7iD#D8ZokVdg