mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00
Add SliceSafetensors. (#2179)
* Add SlicedSafetensors. * And add some testing.
This commit is contained in:
@ -349,6 +349,30 @@ impl MmapedSafetensors {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SliceSafetensors<'a> {
|
||||
safetensors: SafeTensors<'a>,
|
||||
}
|
||||
|
||||
impl<'a> SliceSafetensors<'a> {
|
||||
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
|
||||
pub fn new(buffer: &'a [u8]) -> Result<Self> {
|
||||
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
|
||||
Ok(Self { safetensors })
|
||||
}
|
||||
|
||||
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||
self.safetensors.tensor(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
self.safetensors.tensors()
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||
Ok(self.safetensors.tensor(name)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BufferedSafetensors {
|
||||
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
||||
}
|
||||
|
Reference in New Issue
Block a user