mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Expose the tensor write-bytes function. (#412)
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
//! Implement conversion traits for tensors
|
||||
use crate::{Device, Error, Tensor, WithDType};
|
||||
use half::{bf16, f16};
|
||||
use crate::{DType, Device, Error, Tensor, WithDType};
|
||||
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
||||
use std::convert::TryFrom;
|
||||
|
||||
impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
|
||||
@ -94,3 +94,46 @@ from_tensor!(f16);
|
||||
from_tensor!(bf16);
|
||||
from_tensor!(u32);
|
||||
from_tensor!(u8);
|
||||
|
||||
impl Tensor {
|
||||
pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
|
||||
use byteorder::{LittleEndian, WriteBytesExt};
|
||||
|
||||
let vs = self.flatten_all()?;
|
||||
match self.dtype() {
|
||||
DType::BF16 => {
|
||||
let vs = vs.to_vec1::<bf16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
let vs = vs.to_vec1::<f16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F32 => {
|
||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||
for v in vs.to_vec1::<f32>()? {
|
||||
f.write_f32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F64 => {
|
||||
for v in vs.to_vec1::<f64>()? {
|
||||
f.write_f64::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U32 => {
|
||||
for v in vs.to_vec1::<u32>()? {
|
||||
f.write_u32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U8 => {
|
||||
let vs = vs.to_vec1::<u8>()?;
|
||||
f.write_all(&vs)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -26,7 +26,7 @@
|
||||
//! values = np.loadz("test.npz")
|
||||
//! ```
|
||||
use crate::{DType, Device, Error, Result, Shape, Tensor};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
@ -307,42 +307,7 @@ impl Tensor {
|
||||
header.push('\n');
|
||||
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
|
||||
f.write_all(header.as_bytes())?;
|
||||
let vs = self.flatten_all()?;
|
||||
match self.dtype() {
|
||||
DType::BF16 => {
|
||||
let vs = vs.to_vec1::<bf16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
let vs = vs.to_vec1::<f16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F32 => {
|
||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||
for v in vs.to_vec1::<f32>()? {
|
||||
f.write_f32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F64 => {
|
||||
for v in vs.to_vec1::<f64>()? {
|
||||
f.write_f64::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U32 => {
|
||||
for v in vs.to_vec1::<u32>()? {
|
||||
f.write_u32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U8 => {
|
||||
let vs = vs.to_vec1::<u8>()?;
|
||||
f.write_all(&vs)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
self.write_bytes(f)
|
||||
}
|
||||
|
||||
/// Writes a multi-dimensional array in the npy format.
|
||||
|
Reference in New Issue
Block a user