Expose the tensor write-bytes function. (#412)

This commit is contained in:
Laurent Mazare
2023-08-11 18:13:42 +02:00
committed by GitHub
parent 55e428c8ae
commit e12372021b
2 changed files with 47 additions and 39 deletions

View File

@ -1,6 +1,6 @@
//! Implement conversion traits for tensors
use crate::{Device, Error, Tensor, WithDType};
use half::{bf16, f16};
use crate::{DType, Device, Error, Tensor, WithDType};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;
impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
@ -94,3 +94,46 @@ from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(u32);
from_tensor!(u8);
impl Tensor {
pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
use byteorder::{LittleEndian, WriteBytesExt};
let vs = self.flatten_all()?;
match self.dtype() {
DType::BF16 => {
let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::U8 => {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
}
}
Ok(())
}
}

View File

@ -26,7 +26,7 @@
//! values = np.loadz("test.npz")
//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use byteorder::{LittleEndian, ReadBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
@ -307,42 +307,7 @@ impl Tensor {
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
let vs = self.flatten_all()?;
match self.dtype() {
DType::BF16 => {
let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::U8 => {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
}
}
Ok(())
self.write_bytes(f)
}
/// Writes a multi-dimensional array in the npy format.