From e32c89d90cdab35fb89909e555c41e3d54920a26 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 23 Sep 2023 22:57:42 +0100
Subject: [PATCH] Add the buffered safetensor wrapper. (#948)

---
 candle-core/src/safetensors.rs | 34 ++++++++++++++++++++++++++++++++++
 candle-nn/src/var_builder.rs   | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+)

diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 12df7fbe..02d54e68 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -349,6 +349,40 @@ impl MmapedSafetensors {
     }
 }
 
+/// Safetensors data backed by an owned in-memory byte buffer.
+pub struct BufferedSafetensors {
+    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
+}
+
+impl BufferedSafetensors {
+    /// Creates a wrapper around a binary buffer and deserializes the safetensors header.
+    pub fn new(buffer: Vec<u8>) -> Result<Self> {
+        let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
+            buffer,
+            |data: &[u8]| {
+                let st = safetensors::SafeTensors::deserialize(data)?;
+                Ok::<_, Error>(SafeTensors_(st))
+            },
+        )?;
+        Ok(Self { safetensors })
+    }
+
+    /// Loads the tensor called `name` onto the device `dev`.
+    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
+        self.get(name)?.load(dev)
+    }
+
+    /// Returns the name and view of every tensor held in the buffer.
+    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
+        self.safetensors.get().0.tensors()
+    }
+
+    /// Returns a view on the tensor called `name`, or an error if the lookup fails.
+    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
+        Ok(self.safetensors.get().0.tensor(name)?)
+    }
+}
+
 pub struct MmapedFile {
     path: std::path::PathBuf,
     inner: memmap2::Mmap,
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 7b733e0c..220bae1b 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -351,6 +351,33 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors {
     }
 }
 
+impl SimpleBackend for candle::safetensors::BufferedSafetensors {
+    fn get(
+        &self,
+        s: Shape,
+        name: &str,
+        _: crate::Init,
+        dtype: DType,
+        dev: &Device,
+    ) -> Result<Tensor> {
+        let tensor = self.load(name, dev)?.to_dtype(dtype)?;
+        if tensor.shape() != &s {
+            Err(candle::Error::UnexpectedShape {
+                msg: format!("shape mismatch for {name}"),
+                expected: s,
+                got: tensor.shape().clone(),
+            }
+            .bt())?
+        }
+        Ok(tensor)
+    }
+
+    fn contains_tensor(&self, name: &str) -> bool {
+        // Calls the inherent `BufferedSafetensors::get` lookup, not `SimpleBackend::get`.
+        self.get(name).is_ok()
+    }
+}
+
 impl<'a> VarBuilder<'a> {
     fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
         let data = TensorData {
@@ -417,6 +444,12 @@ impl<'a> VarBuilder<'a> {
         Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
     }
 
+    /// Initializes a `VarBuilder` from a binary buffer in the safetensor format.
+    pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
+        let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
+        Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
+    }
+
     /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
     pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
         let npz = candle::npy::NpzTensors::new(p)?;