Files
candle/candle-core/tests/serialization_tests.rs
Laurent Mazare 21f82a5155 Add SliceSafetensors. (#2179)
* Add SlicedSafetensors.

* And add some testing.
2024-05-11 13:15:42 +02:00

72 lines
1.9 KiB
Rust

use candle_core::{DType, Result, Tensor};
struct TmpFile(std::path::PathBuf);
impl TmpFile {
fn create(base: &str) -> TmpFile {
let filename = std::env::temp_dir().join(format!(
"candle-{}-{}-{:?}",
base,
std::process::id(),
std::thread::current().id(),
));
TmpFile(filename)
}
}
impl std::convert::AsRef<std::path::Path> for TmpFile {
fn as_ref(&self) -> &std::path::Path {
self.0.as_path()
}
}
impl Drop for TmpFile {
fn drop(&mut self) {
std::fs::remove_file(&self.0).unwrap()
}
}
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
assert_eq!(
npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
);
Ok(())
}
#[test]
fn npz() -> Result<()> {
let npz = Tensor::read_npz("tests/test.npz")?;
assert_eq!(npz.len(), 2);
assert_eq!(npz[0].0, "x");
assert_eq!(npz[1].0, "x_plus_one");
assert_eq!(
npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
);
Ok(())
}
#[test]
fn safetensors() -> Result<()> {
use candle_core::safetensors::Load;
let tmp_file = TmpFile::create("st");
let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
t.save_safetensors("t", &tmp_file)?;
// Load from file.
let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
let t2 = st.get("t").unwrap();
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0f32);
// Load from bytes.
let bytes = std::fs::read(tmp_file)?;
let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0f32);
Ok(())
}