Npy tweaks & error with path (#384)

* Simplify the npy writing.

* Wrap the file path so as to provide better errors.
This commit is contained in:
Laurent Mazare
2023-08-10 07:21:58 +02:00
committed by GitHub
parent c7f92f985e
commit f3fe730a30
3 changed files with 38 additions and 14 deletions

View File

@ -185,6 +185,13 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>), Wrapped(Box<dyn std::error::Error + Send + Sync>),
/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
WithPath {
inner: Box<Self>,
path: std::path::PathBuf,
},
#[error("{inner}\n{backtrace}")] #[error("{inner}\n{backtrace}")]
WithBacktrace { WithBacktrace {
inner: Box<Self>, inner: Box<Self>,
@ -214,6 +221,13 @@ impl Error {
}, },
} }
} }
pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
Self::WithPath {
inner: Box::new(self),
path: p.as_ref().to_path_buf(),
}
}
} }
#[macro_export] #[macro_export]

View File

@ -307,39 +307,39 @@ impl Tensor {
header.push('\n'); header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?; f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?; f.write_all(header.as_bytes())?;
let elem_count = self.elem_count(); let vs = self.flatten_all()?;
match self.dtype() { match self.dtype() {
DType::BF16 => { DType::BF16 => {
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?; let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() { for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)? f.write_u16::<LittleEndian>(v)?
} }
} }
DType::F16 => { DType::F16 => {
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?; let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() { for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)? f.write_u16::<LittleEndian>(v)?
} }
} }
DType::F32 => { DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU. // TODO: Avoid using a buffer when data is already on the CPU.
for v in self.reshape(elem_count)?.to_vec1::<f32>()? { for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)? f.write_f32::<LittleEndian>(v)?
} }
} }
DType::F64 => { DType::F64 => {
for v in self.reshape(elem_count)?.to_vec1::<f64>()? { for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)? f.write_f64::<LittleEndian>(v)?
} }
} }
DType::U32 => { DType::U32 => {
for v in self.reshape(elem_count)?.to_vec1::<u32>()? { for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)? f.write_u32::<LittleEndian>(v)?
} }
} }
DType::U8 => { DType::U8 => {
let data = self.reshape(elem_count)?.to_vec1::<u8>()?; let vs = vs.to_vec1::<u8>()?;
f.write_all(&data)?; f.write_all(&vs)?;
} }
} }
Ok(()) Ok(())
@ -373,7 +373,7 @@ pub struct NpzTensors {
index_per_name: HashMap<String, usize>, index_per_name: HashMap<String, usize>,
path: std::path::PathBuf, path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we // We do not store a zip reader as it needs mutable access to extract data. Instead we
// re-create a zip reader each time. // re-create a zip reader for each tensor.
} }
impl NpzTensors { impl NpzTensors {

View File

@ -257,7 +257,10 @@ pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
} }
pub struct MmapedFile(memmap2::Mmap); pub struct MmapedFile {
path: std::path::PathBuf,
inner: memmap2::Mmap,
}
impl MmapedFile { impl MmapedFile {
/// Creates a wrapper around a memory mapped file from which you can retrieve /// Creates a wrapper around a memory mapped file from which you can retrieve
@ -267,13 +270,20 @@ impl MmapedFile {
/// ///
/// The unsafe is inherited from [`memmap2::MmapOptions`]. /// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let file = std::fs::File::open(p)?; let p = p.as_ref();
let mmap = memmap2::MmapOptions::new().map(&file)?; let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
Ok(Self(mmap)) let inner = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
Ok(Self {
inner,
path: p.to_path_buf(),
})
} }
pub fn deserialize(&self) -> Result<SafeTensors<'_>> { pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
let st = safetensors::SafeTensors::deserialize(&self.0)?; let st = safetensors::SafeTensors::deserialize(&self.inner)
.map_err(|e| Error::from(e).with_path(&self.path))?;
Ok(st) Ok(st)
} }
} }