mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Expose the tensor write-bytes function. (#412)
This commit is contained in:
@ -1,6 +1,6 @@
|
|||||||
//! Implement conversion traits for tensors
|
//! Implement conversion traits for tensors
|
||||||
use crate::{Device, Error, Tensor, WithDType};
|
use crate::{DType, Device, Error, Tensor, WithDType};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
|
|
||||||
impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
|
impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
|
||||||
@ -94,3 +94,46 @@ from_tensor!(f16);
|
|||||||
from_tensor!(bf16);
|
from_tensor!(bf16);
|
||||||
from_tensor!(u32);
|
from_tensor!(u32);
|
||||||
from_tensor!(u8);
|
from_tensor!(u8);
|
||||||
|
|
||||||
|
impl Tensor {
|
||||||
|
pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
|
||||||
|
use byteorder::{LittleEndian, WriteBytesExt};
|
||||||
|
|
||||||
|
let vs = self.flatten_all()?;
|
||||||
|
match self.dtype() {
|
||||||
|
DType::BF16 => {
|
||||||
|
let vs = vs.to_vec1::<bf16>()?;
|
||||||
|
for &v in vs.reinterpret_cast() {
|
||||||
|
f.write_u16::<LittleEndian>(v)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let vs = vs.to_vec1::<f16>()?;
|
||||||
|
for &v in vs.reinterpret_cast() {
|
||||||
|
f.write_u16::<LittleEndian>(v)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||||
|
for v in vs.to_vec1::<f32>()? {
|
||||||
|
f.write_f32::<LittleEndian>(v)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
for v in vs.to_vec1::<f64>()? {
|
||||||
|
f.write_f64::<LittleEndian>(v)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
for v in vs.to_vec1::<u32>()? {
|
||||||
|
f.write_u32::<LittleEndian>(v)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::U8 => {
|
||||||
|
let vs = vs.to_vec1::<u8>()?;
|
||||||
|
f.write_all(&vs)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -26,7 +26,7 @@
|
|||||||
//! values = np.loadz("test.npz")
|
//! values = np.loadz("test.npz")
|
||||||
//! ```
|
//! ```
|
||||||
use crate::{DType, Device, Error, Result, Shape, Tensor};
|
use crate::{DType, Device, Error, Result, Shape, Tensor};
|
||||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt};
|
||||||
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
@ -307,42 +307,7 @@ impl Tensor {
|
|||||||
header.push('\n');
|
header.push('\n');
|
||||||
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
|
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
|
||||||
f.write_all(header.as_bytes())?;
|
f.write_all(header.as_bytes())?;
|
||||||
let vs = self.flatten_all()?;
|
self.write_bytes(f)
|
||||||
match self.dtype() {
|
|
||||||
DType::BF16 => {
|
|
||||||
let vs = vs.to_vec1::<bf16>()?;
|
|
||||||
for &v in vs.reinterpret_cast() {
|
|
||||||
f.write_u16::<LittleEndian>(v)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
DType::F16 => {
|
|
||||||
let vs = vs.to_vec1::<f16>()?;
|
|
||||||
for &v in vs.reinterpret_cast() {
|
|
||||||
f.write_u16::<LittleEndian>(v)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
|
||||||
for v in vs.to_vec1::<f32>()? {
|
|
||||||
f.write_f32::<LittleEndian>(v)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
for v in vs.to_vec1::<f64>()? {
|
|
||||||
f.write_f64::<LittleEndian>(v)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
DType::U32 => {
|
|
||||||
for v in vs.to_vec1::<u32>()? {
|
|
||||||
f.write_u32::<LittleEndian>(v)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
DType::U8 => {
|
|
||||||
let vs = vs.to_vec1::<u8>()?;
|
|
||||||
f.write_all(&vs)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Writes a multi-dimensional array in the npy format.
|
/// Writes a multi-dimensional array in the npy format.
|
||||||
|
Reference in New Issue
Block a user