Handle f16/bf16 in npy.

This commit is contained in:
laurent
2023-06-26 21:10:03 +01:00
parent de1f612645
commit 7cfa4c307c

View File

@ -27,6 +27,7 @@
//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Write};
@ -80,7 +81,7 @@ impl Header {
.collect::<Vec<_>>()
.join(",");
let descr = match self.descr {
DType::BF16 => todo!("bf16"),
DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
DType::F16 => "f2",
DType::F32 => "f4",
DType::F64 => "f8",
@ -193,14 +194,19 @@ impl Header {
}
impl Tensor {
// TODO: Add the possibility to read directly to a device?
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
let elem_count = shape.elem_count();
match dtype {
DType::BF16 => {
todo!("bf16")
let mut data_t = vec![bf16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F16 => {
todo!("f16")
let mut data_t = vec![f16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F32 => {
let mut data_t = vec![0f32; elem_count];
@ -298,10 +304,16 @@ impl Tensor {
let elem_count = self.elem_count();
match self.dtype() {
DType::BF16 => {
todo!("bf16")
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
todo!("f16")
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.