Add SliceSafetensors. (#2179)

* Add SlicedSafetensors.

* And add some testing.
This commit is contained in:
Laurent Mazare
2024-05-11 13:15:42 +02:00
committed by GitHub
parent 9cff7bc3f4
commit 21f82a5155
3 changed files with 77 additions and 0 deletions

View File

@ -349,6 +349,30 @@ impl MmapedSafetensors {
}
}
pub struct SliceSafetensors<'a> {
safetensors: SafeTensors<'a>,
}
impl<'a> SliceSafetensors<'a> {
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
pub fn new(buffer: &'a [u8]) -> Result<Self> {
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
Ok(Self { safetensors })
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.safetensors.tensor(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
self.safetensors.tensors()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
Ok(self.safetensors.tensor(name)?)
}
}
pub struct BufferedSafetensors {
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}