mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Npy tweaks & error with path (#384)
* Simplify the npy writing.
* Wrap the file path so as to provide better errors.
This commit is contained in:
@ -185,6 +185,13 @@ pub enum Error {
|
||||
#[error(transparent)]
|
||||
Wrapped(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
/// Adding path information to an error.
|
||||
#[error("path: {path:?} {inner}")]
|
||||
WithPath {
|
||||
inner: Box<Self>,
|
||||
path: std::path::PathBuf,
|
||||
},
|
||||
|
||||
#[error("{inner}\n{backtrace}")]
|
||||
WithBacktrace {
|
||||
inner: Box<Self>,
|
||||
@ -214,6 +221,13 @@ impl Error {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
|
||||
Self::WithPath {
|
||||
inner: Box::new(self),
|
||||
path: p.as_ref().to_path_buf(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
|
@ -307,39 +307,39 @@ impl Tensor {
|
||||
header.push('\n');
|
||||
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
|
||||
f.write_all(header.as_bytes())?;
|
||||
let elem_count = self.elem_count();
|
||||
let vs = self.flatten_all()?;
|
||||
match self.dtype() {
|
||||
DType::BF16 => {
|
||||
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
|
||||
let vs = vs.to_vec1::<bf16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
|
||||
let vs = vs.to_vec1::<f16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F32 => {
|
||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
|
||||
for v in vs.to_vec1::<f32>()? {
|
||||
f.write_f32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F64 => {
|
||||
for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
|
||||
for v in vs.to_vec1::<f64>()? {
|
||||
f.write_f64::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U32 => {
|
||||
for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
|
||||
for v in vs.to_vec1::<u32>()? {
|
||||
f.write_u32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U8 => {
|
||||
let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
|
||||
f.write_all(&data)?;
|
||||
let vs = vs.to_vec1::<u8>()?;
|
||||
f.write_all(&vs)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
@ -373,7 +373,7 @@ pub struct NpzTensors {
|
||||
index_per_name: HashMap<String, usize>,
|
||||
path: std::path::PathBuf,
|
||||
// We do not store a zip reader as it needs mutable access to extract data. Instead we
|
||||
// re-create a zip reader each time.
|
||||
// re-create a zip reader for each tensor.
|
||||
}
|
||||
|
||||
impl NpzTensors {
|
||||
|
@ -257,7 +257,10 @@ pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res
|
||||
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
|
||||
}
|
||||
|
||||
pub struct MmapedFile(memmap2::Mmap);
|
||||
pub struct MmapedFile {
|
||||
path: std::path::PathBuf,
|
||||
inner: memmap2::Mmap,
|
||||
}
|
||||
|
||||
impl MmapedFile {
|
||||
/// Creates a wrapper around a memory mapped file from which you can retrieve
|
||||
@ -267,13 +270,20 @@ impl MmapedFile {
|
||||
///
|
||||
/// The unsafe is inherited from [`memmap2::MmapOptions`].
|
||||
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||
let file = std::fs::File::open(p)?;
|
||||
let mmap = memmap2::MmapOptions::new().map(&file)?;
|
||||
Ok(Self(mmap))
|
||||
let p = p.as_ref();
|
||||
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
|
||||
let inner = memmap2::MmapOptions::new()
|
||||
.map(&file)
|
||||
.map_err(|e| Error::from(e).with_path(p))?;
|
||||
Ok(Self {
|
||||
inner,
|
||||
path: p.to_path_buf(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
|
||||
let st = safetensors::SafeTensors::deserialize(&self.0)?;
|
||||
let st = safetensors::SafeTensors::deserialize(&self.inner)
|
||||
.map_err(|e| Error::from(e).with_path(&self.path))?;
|
||||
Ok(st)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user