diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index f3a75965..1dfcc9b4 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -250,8 +250,6 @@ impl Tensor { if header.fortran_order { return Err(Error::Npy("fortran order not supported".to_string())); } - let mut data: Vec = vec![]; - reader.read_to_end(&mut data)?; Self::from_reader(header.shape(), header.descr, &mut reader) } diff --git a/candle-core/tests/npy.py b/candle-core/tests/npy.py new file mode 100644 index 00000000..0fd2778a --- /dev/null +++ b/candle-core/tests/npy.py @@ -0,0 +1,9 @@ +import numpy as np +x = np.arange(10) + +# Write a npy file. +np.save("test.npy", x) + +# Write multiple values to a npz file. +values = { "x": x, "x_plus_one": x + 1 } +np.savez("test.npz", **values) diff --git a/candle-core/tests/serialization_tests.rs b/candle-core/tests/serialization_tests.rs new file mode 100644 index 00000000..415306f4 --- /dev/null +++ b/candle-core/tests/serialization_tests.rs @@ -0,0 +1,24 @@ +use candle_core::{DType, Result, Tensor}; + +#[test] +fn npy() -> Result<()> { + let npy = Tensor::read_npy("tests/test.npy")?; + assert_eq!( + npy.to_dtype(DType::U8)?.to_vec1::()?, + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ); + Ok(()) +} + +#[test] +fn npz() -> Result<()> { + let npz = Tensor::read_npz("tests/test.npz")?; + assert_eq!(npz.len(), 2); + assert_eq!(npz[0].0, "x"); + assert_eq!(npz[1].0, "x_plus_one"); + assert_eq!( + npz[1].1.to_dtype(DType::U8)?.to_vec1::()?, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + ); + Ok(()) +} diff --git a/candle-core/tests/test.npy b/candle-core/tests/test.npy new file mode 100644 index 00000000..a3ff5af9 Binary files /dev/null and b/candle-core/tests/test.npy differ diff --git a/candle-core/tests/test.npz b/candle-core/tests/test.npz new file mode 100644 index 00000000..b6683caa Binary files /dev/null and b/candle-core/tests/test.npz differ