diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 02d54e68..5ea1f192 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -349,6 +349,30 @@ impl MmapedSafetensors { } } +pub struct SliceSafetensors<'a> { + safetensors: SafeTensors<'a>, +} + +impl<'a> SliceSafetensors<'a> { + /// Creates a wrapper around a binary buffer and deserialize the safetensors header. + pub fn new(buffer: &'a [u8]) -> Result { + let safetensors = safetensors::SafeTensors::deserialize(buffer)?; + Ok(Self { safetensors }) + } + + pub fn load(&self, name: &str, dev: &Device) -> Result { + self.safetensors.tensor(name)?.load(dev) + } + + pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { + self.safetensors.tensors() + } + + pub fn get(&self, name: &str) -> Result> { + Ok(self.safetensors.tensor(name)?) + } +} + pub struct BufferedSafetensors { safetensors: yoke::Yoke, Vec>, } diff --git a/candle-core/tests/serialization_tests.rs b/candle-core/tests/serialization_tests.rs index 415306f4..f81350e6 100644 --- a/candle-core/tests/serialization_tests.rs +++ b/candle-core/tests/serialization_tests.rs @@ -1,5 +1,31 @@ use candle_core::{DType, Result, Tensor}; +struct TmpFile(std::path::PathBuf); + +impl TmpFile { + fn create(base: &str) -> TmpFile { + let filename = std::env::temp_dir().join(format!( + "candle-{}-{}-{:?}", + base, + std::process::id(), + std::thread::current().id(), + )); + TmpFile(filename) + } +} + +impl std::convert::AsRef for TmpFile { + fn as_ref(&self) -> &std::path::Path { + self.0.as_path() + } +} + +impl Drop for TmpFile { + fn drop(&mut self) { + std::fs::remove_file(&self.0).unwrap() + } +} + #[test] fn npy() -> Result<()> { let npy = Tensor::read_npy("tests/test.npy")?; @@ -22,3 +48,24 @@ fn npz() -> Result<()> { ); Ok(()) } + +#[test] +fn safetensors() -> Result<()> { + use candle_core::safetensors::Load; + + let tmp_file = TmpFile::create("st"); + let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?; + t.save_safetensors("t", &tmp_file)?; + // Load from file. + let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?; + let t2 = st.get("t").unwrap(); + let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0f32); + // Load from bytes. + let bytes = std::fs::read(tmp_file)?; + let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?; + let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu); + let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0f32); + Ok(()) +} diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 68bd6f05..ebbc9084 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -487,6 +487,12 @@ impl<'a> VarBuilder<'a> { Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) } + /// Initializes a `VarBuilder` from a binary builder in the safetensor format. + pub fn from_slice_safetensors(data: Vec, dtype: DType, dev: &Device) -> Result { + let tensors = candle::safetensors::BufferedSafetensors::new(data)?; + Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) + } + /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file. pub fn from_npz>(p: P, dtype: DType, dev: &Device) -> Result { let npz = candle::npy::NpzTensors::new(p)?;