Handle f16/bf16 in npy.

This commit is contained in:
laurent
2023-06-26 21:10:03 +01:00
parent de1f612645
commit 7cfa4c307c

View File

@ -27,6 +27,7 @@
//! ``` //! ```
use crate::{DType, Device, Error, Result, Shape, Tensor}; use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap; use std::collections::HashMap;
use std::fs::File; use std::fs::File;
use std::io::{BufReader, Read, Write}; use std::io::{BufReader, Read, Write};
@ -80,7 +81,7 @@ impl Header {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(","); .join(",");
let descr = match self.descr { let descr = match self.descr {
DType::BF16 => todo!("bf16"), DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
DType::F16 => "f2", DType::F16 => "f2",
DType::F32 => "f4", DType::F32 => "f4",
DType::F64 => "f8", DType::F64 => "f8",
@ -193,14 +194,19 @@ impl Header {
} }
impl Tensor { impl Tensor {
// TODO: Add the possibility to read directly to a device?
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> { fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
match dtype { match dtype {
DType::BF16 => { DType::BF16 => {
todo!("bf16") let mut data_t = vec![bf16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
} }
DType::F16 => { DType::F16 => {
todo!("f16") let mut data_t = vec![f16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
} }
DType::F32 => { DType::F32 => {
let mut data_t = vec![0f32; elem_count]; let mut data_t = vec![0f32; elem_count];
@ -298,10 +304,16 @@ impl Tensor {
let elem_count = self.elem_count(); let elem_count = self.elem_count();
match self.dtype() { match self.dtype() {
DType::BF16 => { DType::BF16 => {
todo!("bf16") let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
} }
DType::F16 => { DType::F16 => {
todo!("f16") let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
} }
DType::F32 => { DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU. // TODO: Avoid using a buffer when data is already on the CPU.