Add SliceSafetensors. (#2179)

* Add SlicedSafetensors.

* And add some testing.
This commit is contained in:
Laurent Mazare
2024-05-11 13:15:42 +02:00
committed by GitHub
parent 9cff7bc3f4
commit 21f82a5155
3 changed files with 77 additions and 0 deletions

View File

@ -487,6 +487,12 @@ impl<'a> VarBuilder<'a> {
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
pub fn from_slice_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(p)?;