Fix the npy read function and add some testing. (#1080)

This commit is contained in:
Laurent Mazare
2023-10-12 15:25:05 +02:00
committed by GitHub
parent c096f02411
commit 7473c4ceca
5 changed files with 33 additions and 2 deletions

View File

@ -250,8 +250,6 @@ impl Tensor {
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let mut data: Vec<u8> = vec![];
reader.read_to_end(&mut data)?;
Self::from_reader(header.shape(), header.descr, &mut reader)
}

9
candle-core/tests/npy.py Normal file
View File

@ -0,0 +1,9 @@
import numpy as np
x = np.arange(10)
# Write a npy file.
np.save("test.npy", x)
# Write multiple values to a npz file.
values = { "x": x, "x_plus_one": x + 1 }
np.savez("test.npz", **values)

View File

@ -0,0 +1,24 @@
use candle_core::{DType, Result, Tensor};
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
assert_eq!(
npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
);
Ok(())
}
#[test]
fn npz() -> Result<()> {
let npz = Tensor::read_npz("tests/test.npz")?;
assert_eq!(npz.len(), 2);
assert_eq!(npz[0].0, "x");
assert_eq!(npz[1].0, "x_plus_one");
assert_eq!(
npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
);
Ok(())
}

BIN
candle-core/tests/test.npy Normal file

Binary file not shown.

BIN
candle-core/tests/test.npz Normal file

Binary file not shown.