diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 35a33032..c18b43c6 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -185,6 +185,13 @@ pub enum Error { #[error(transparent)] Wrapped(Box), + /// Adding path information to an error. + #[error("path: {path:?} {inner}")] + WithPath { + inner: Box, + path: std::path::PathBuf, + }, + #[error("{inner}\n{backtrace}")] WithBacktrace { inner: Box, @@ -214,6 +221,13 @@ impl Error { }, } } + + pub fn with_path>(self, p: P) -> Self { + Self::WithPath { + inner: Box::new(self), + path: p.as_ref().to_path_buf(), + } + } } #[macro_export] diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 6302cf71..e17ba02a 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -307,39 +307,39 @@ impl Tensor { header.push('\n'); f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?; f.write_all(header.as_bytes())?; - let elem_count = self.elem_count(); + let vs = self.flatten_all()?; match self.dtype() { DType::BF16 => { - let vs = self.reshape(elem_count)?.to_vec1::()?; + let vs = vs.to_vec1::()?; for &v in vs.reinterpret_cast() { f.write_u16::(v)? } } DType::F16 => { - let vs = self.reshape(elem_count)?.to_vec1::()?; + let vs = vs.to_vec1::()?; for &v in vs.reinterpret_cast() { f.write_u16::(v)? } } DType::F32 => { // TODO: Avoid using a buffer when data is already on the CPU. - for v in self.reshape(elem_count)?.to_vec1::()? { + for v in vs.to_vec1::()? { f.write_f32::(v)? } } DType::F64 => { - for v in self.reshape(elem_count)?.to_vec1::()? { + for v in vs.to_vec1::()? { f.write_f64::(v)? } } DType::U32 => { - for v in self.reshape(elem_count)?.to_vec1::()? { + for v in vs.to_vec1::()? { f.write_u32::(v)? } } DType::U8 => { - let data = self.reshape(elem_count)?.to_vec1::()?; - f.write_all(&data)?; + let vs = vs.to_vec1::()?; + f.write_all(&vs)?; } } Ok(()) @@ -373,7 +373,7 @@ pub struct NpzTensors { index_per_name: HashMap, path: std::path::PathBuf, // We do not store a zip reader as it needs mutable access to extract data. Instead we - // re-create a zip reader each time. + // re-create a zip reader for each tensor. } impl NpzTensors { diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 132fb914..914e5101 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -257,7 +257,10 @@ pub fn save>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) } -pub struct MmapedFile(memmap2::Mmap); +pub struct MmapedFile { + path: std::path::PathBuf, + inner: memmap2::Mmap, +} impl MmapedFile { /// Creates a wrapper around a memory mapped file from which you can retrieve @@ -267,13 +270,20 @@ impl MmapedFile { /// /// The unsafe is inherited from [`memmap2::MmapOptions`]. pub unsafe fn new>(p: P) -> Result { - let file = std::fs::File::open(p)?; - let mmap = memmap2::MmapOptions::new().map(&file)?; - Ok(Self(mmap)) + let p = p.as_ref(); + let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; + let inner = memmap2::MmapOptions::new() + .map(&file) + .map_err(|e| Error::from(e).with_path(p))?; + Ok(Self { + inner, + path: p.to_path_buf(), + }) } pub fn deserialize(&self) -> Result> { - let st = safetensors::SafeTensors::deserialize(&self.0)?; + let st = safetensors::SafeTensors::deserialize(&self.inner) + .map_err(|e| Error::from(e).with_path(&self.path))?; Ok(st) } }