mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add SliceSafetensors. (#2179)
* Add SliceSafetensors. * And add some testing.
This commit is contained in:
@ -349,6 +349,30 @@ impl MmapedSafetensors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct SliceSafetensors<'a> {
|
||||||
|
safetensors: SafeTensors<'a>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SliceSafetensors<'a> {
|
||||||
|
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
|
||||||
|
pub fn new(buffer: &'a [u8]) -> Result<Self> {
|
||||||
|
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
|
||||||
|
Ok(Self { safetensors })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||||
|
self.safetensors.tensor(name)?.load(dev)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||||
|
self.safetensors.tensors()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||||
|
Ok(self.safetensors.tensor(name)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct BufferedSafetensors {
|
pub struct BufferedSafetensors {
|
||||||
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,31 @@
|
|||||||
use candle_core::{DType, Result, Tensor};
|
use candle_core::{DType, Result, Tensor};
|
||||||
|
|
||||||
|
/// Path of a temporary file that is deleted when the value is dropped.
///
/// Only the path is stored; the file itself is expected to be written by the
/// code under test.
struct TmpFile(std::path::PathBuf);

impl TmpFile {
    /// Builds a path inside the system temporary directory.
    ///
    /// The file name combines `base`, the current process id and the current
    /// thread id. No file is created on disk by this call.
    fn create(base: &str) -> TmpFile {
        let filename = std::env::temp_dir().join(format!(
            "candle-{}-{}-{:?}",
            base,
            std::process::id(),
            std::thread::current().id(),
        ));
        TmpFile(filename)
    }
}

impl std::convert::AsRef<std::path::Path> for TmpFile {
    /// Exposes the stored path so a `TmpFile` can be passed where a path is expected.
    fn as_ref(&self) -> &std::path::Path {
        self.0.as_path()
    }
}

impl Drop for TmpFile {
    /// Removes the file on a best-effort basis.
    ///
    /// Never panics: the file may not exist (e.g. a test failed before writing
    /// it), and panicking here would hide the original failure or abort the
    /// process if it happened while already unwinding.
    fn drop(&mut self) {
        if let Err(err) = std::fs::remove_file(&self.0) {
            // A missing file means there is nothing to clean up.
            if err.kind() != std::io::ErrorKind::NotFound {
                eprintln!(
                    "failed to remove temporary file {}: {}",
                    self.0.display(),
                    err
                );
            }
        }
    }
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn npy() -> Result<()> {
|
fn npy() -> Result<()> {
|
||||||
let npy = Tensor::read_npy("tests/test.npy")?;
|
let npy = Tensor::read_npy("tests/test.npy")?;
|
||||||
@ -22,3 +48,24 @@ fn npz() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Saves a tensor to a safetensors file, then checks that reading it back
/// from the file and from the raw bytes both give the original values.
#[test]
fn safetensors() -> Result<()> {
    use candle_core::safetensors::Load;

    let cpu = candle_core::Device::Cpu;
    let tmp_file = TmpFile::create("st");
    let t = Tensor::arange(0f32, 24f32, &cpu)?;
    t.save_safetensors("t", &tmp_file)?;

    // Load from file.
    let from_file = candle_core::safetensors::load(&tmp_file, &cpu)?;
    let t2 = from_file.get("t").unwrap();
    let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0f32);

    // Load from bytes. `tmp_file` is moved into `read`, so it is dropped here.
    let bytes = std::fs::read(tmp_file)?;
    let from_bytes = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
    let t2 = from_bytes.get("t").unwrap().load(&cpu);
    let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0f32);

    Ok(())
}
|
||||||
|
@ -487,6 +487,12 @@ impl<'a> VarBuilder<'a> {
|
|||||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
|
||||||
|
pub fn from_slice_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
|
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
||||||
|
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
||||||
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
let npz = candle::npy::NpzTensors::new(p)?;
|
let npz = candle::npy::NpzTensors::new(p)?;
|
||||||
|
Reference in New Issue
Block a user