diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index 41a9c4ee..744982fc 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -1,6 +1,6 @@ //! Implement conversion traits for tensors -use crate::{Device, Error, Tensor, WithDType}; -use half::{bf16, f16}; +use crate::{DType, Device, Error, Tensor, WithDType}; +use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::convert::TryFrom; impl TryFrom<&Tensor> for Vec { @@ -94,3 +94,46 @@ from_tensor!(f16); from_tensor!(bf16); from_tensor!(u32); from_tensor!(u8); + +impl Tensor { + pub fn write_bytes(&self, f: &mut W) -> crate::Result<()> { + use byteorder::{LittleEndian, WriteBytesExt}; + + let vs = self.flatten_all()?; + match self.dtype() { + DType::BF16 => { + let vs = vs.to_vec1::()?; + for &v in vs.reinterpret_cast() { + f.write_u16::(v)? + } + } + DType::F16 => { + let vs = vs.to_vec1::()?; + for &v in vs.reinterpret_cast() { + f.write_u16::(v)? + } + } + DType::F32 => { + // TODO: Avoid using a buffer when data is already on the CPU. + for v in vs.to_vec1::()? { + f.write_f32::(v)? + } + } + DType::F64 => { + for v in vs.to_vec1::()? { + f.write_f64::(v)? + } + } + DType::U32 => { + for v in vs.to_vec1::()? { + f.write_u32::(v)? + } + } + DType::U8 => { + let vs = vs.to_vec1::()?; + f.write_all(&vs)?; + } + } + Ok(()) + } +} diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index e17ba02a..2e394b06 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -26,7 +26,7 @@ //! values = np.loadz("test.npz") //! ``` use crate::{DType, Device, Error, Result, Shape, Tensor}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use byteorder::{LittleEndian, ReadBytesExt}; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::collections::HashMap; use std::fs::File; @@ -307,42 +307,7 @@ impl Tensor { header.push('\n'); f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?; f.write_all(header.as_bytes())?; - let vs = self.flatten_all()?; - match self.dtype() { - DType::BF16 => { - let vs = vs.to_vec1::()?; - for &v in vs.reinterpret_cast() { - f.write_u16::(v)? - } - } - DType::F16 => { - let vs = vs.to_vec1::()?; - for &v in vs.reinterpret_cast() { - f.write_u16::(v)? - } - } - DType::F32 => { - // TODO: Avoid using a buffer when data is already on the CPU. - for v in vs.to_vec1::()? { - f.write_f32::(v)? - } - } - DType::F64 => { - for v in vs.to_vec1::()? { - f.write_f64::(v)? - } - } - DType::U32 => { - for v in vs.to_vec1::()? { - f.write_u32::(v)? - } - } - DType::U8 => { - let vs = vs.to_vec1::()?; - f.write_all(&vs)?; - } - } - Ok(()) + self.write_bytes(f) } /// Writes a multi-dimensional array in the npy format.