mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the buffered safetensor wrapper. (#948)
This commit is contained in:
@ -349,6 +349,36 @@ impl MmapedSafetensors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct BufferedSafetensors {
|
||||||
|
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BufferedSafetensors {
|
||||||
|
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
|
||||||
|
pub fn new(buffer: Vec<u8>) -> Result<Self> {
|
||||||
|
let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
|
||||||
|
buffer,
|
||||||
|
|data: &[u8]| {
|
||||||
|
let st = safetensors::SafeTensors::deserialize(data)?;
|
||||||
|
Ok::<_, Error>(SafeTensors_(st))
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
Ok(Self { safetensors })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||||
|
self.get(name)?.load(dev)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||||
|
self.safetensors.get().0.tensors()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||||
|
Ok(self.safetensors.get().0.tensor(name)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct MmapedFile {
|
pub struct MmapedFile {
|
||||||
path: std::path::PathBuf,
|
path: std::path::PathBuf,
|
||||||
inner: memmap2::Mmap,
|
inner: memmap2::Mmap,
|
||||||
|
@ -351,6 +351,32 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl SimpleBackend for candle::safetensors::BufferedSafetensors {
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
s: Shape,
|
||||||
|
name: &str,
|
||||||
|
_: crate::Init,
|
||||||
|
dtype: DType,
|
||||||
|
dev: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
|
||||||
|
if tensor.shape() != &s {
|
||||||
|
Err(candle::Error::UnexpectedShape {
|
||||||
|
msg: format!("shape mismatch for {name}"),
|
||||||
|
expected: s,
|
||||||
|
got: tensor.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
Ok(tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.get(name).is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a> VarBuilder<'a> {
|
impl<'a> VarBuilder<'a> {
|
||||||
fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
|
fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
|
||||||
let data = TensorData {
|
let data = TensorData {
|
||||||
@ -417,6 +443,12 @@ impl<'a> VarBuilder<'a> {
|
|||||||
Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
|
Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
|
||||||
|
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
|
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
||||||
|
Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
||||||
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
let npz = candle::npy::NpzTensors::new(p)?;
|
let npz = candle::npy::NpzTensors::new(p)?;
|
||||||
|
Reference in New Issue
Block a user