mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
[Proposal] Remove SafeTensor wrapper (allows finer control for users).
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
use crate::{DType, Device, Error, Result, Tensor, WithDType};
|
||||
use safetensors::tensor as st;
|
||||
pub use safetensors::tensor::SafeTensors;
|
||||
use std::borrow::Cow;
|
||||
|
||||
impl From<DType> for st::Dtype {
|
||||
@ -62,7 +63,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
let v = view.data();
|
||||
let size_in_bytes = T::DTYPE.size_in_bytes();
|
||||
let elem_count = v.len() / size_in_bytes;
|
||||
@ -101,7 +102,17 @@ fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
|
||||
unsafe { Vec::from_raw_parts(ptr, length, capacity) }
|
||||
}
|
||||
|
||||
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
pub trait Load {
|
||||
fn load(&self, device: &Device) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
impl<'a> Load for st::TensorView<'a> {
|
||||
fn load(&self, device: &Device) -> Result<Tensor> {
|
||||
convert(self, device)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
match view.dtype() {
|
||||
st::Dtype::U8 => convert_::<u8>(view, device),
|
||||
st::Dtype::U32 => convert_::<u8>(view, device),
|
||||
@ -126,13 +137,6 @@ pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
}
|
||||
}
|
||||
|
||||
// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
|
||||
// SafeTensor bits in the same struct and avoid having the final users calling two methods.
|
||||
// We could try using the ouroboros crate or equivalent for this at some point.
|
||||
// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
|
||||
// dtypes, etc
|
||||
pub struct SafeTensors<'a>(st::SafeTensors<'a>);
|
||||
|
||||
pub struct MmapedFile(memmap2::Mmap);
|
||||
|
||||
impl MmapedFile {
|
||||
@ -150,33 +154,7 @@ impl MmapedFile {
|
||||
|
||||
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
|
||||
let st = safetensors::SafeTensors::deserialize(&self.0)?;
|
||||
Ok(SafeTensors(st))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> SafeTensors<'a> {
|
||||
pub fn from_buffer(buffer: &'a [u8]) -> Result<Self> {
|
||||
let st = safetensors::SafeTensors::deserialize(buffer)?;
|
||||
Ok(SafeTensors(st))
|
||||
}
|
||||
|
||||
pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
|
||||
convert(self.0.tensor(name)?, device)
|
||||
}
|
||||
|
||||
pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
|
||||
self.0
|
||||
.tensors()
|
||||
.into_iter()
|
||||
.map(|(name, tensor_view)| {
|
||||
let tensor = convert(tensor_view, device)?;
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn names(&self) -> Vec<&String> {
|
||||
self.0.names()
|
||||
Ok(st)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user