mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add SliceSafetensors. (#2179)
* Add SlicedSafetensors. * And add some testing.
This commit is contained in:
@ -487,6 +487,12 @@ impl<'a> VarBuilder<'a> {
|
||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
|
||||
pub fn from_slice_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
||||
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let npz = candle::npy::NpzTensors::new(p)?;
|
||||
|
Reference in New Issue
Block a user