From 7cfa4c307c00917a24b2e94796c2cb0196ab96d5 Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 26 Jun 2023 21:10:03 +0100
Subject: [PATCH] Handle f16/bf16 in npy.

---
 src/npy.rs | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/npy.rs b/src/npy.rs
index 3eb4e7c1..7e157c8f 100644
--- a/src/npy.rs
+++ b/src/npy.rs
@@ -27,6 +27,7 @@
 //! ```
 use crate::{DType, Device, Error, Result, Shape, Tensor};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use half::{bf16, f16, slice::HalfFloatSliceExt};
 use std::collections::HashMap;
 use std::fs::File;
 use std::io::{BufReader, Read, Write};
@@ -80,7 +81,7 @@ impl Header {
             .collect::<Vec<_>>()
             .join(",");
         let descr = match self.descr {
-            DType::BF16 => todo!("bf16"),
+            DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
             DType::F16 => "f2",
             DType::F32 => "f4",
             DType::F64 => "f8",
@@ -193,14 +194,19 @@ impl Header {
 }
 
 impl Tensor {
+    // TODO: Add the possibility to read directly to a device?
     fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
         let elem_count = shape.elem_count();
         match dtype {
             DType::BF16 => {
-                todo!("bf16")
+                let mut data_t = vec![bf16::ZERO; elem_count];
+                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
             DType::F16 => {
-                todo!("f16")
+                let mut data_t = vec![f16::ZERO; elem_count];
+                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
             DType::F32 => {
                 let mut data_t = vec![0f32; elem_count];
@@ -298,10 +304,16 @@ impl Tensor {
         let elem_count = self.elem_count();
         match self.dtype() {
             DType::BF16 => {
-                todo!("bf16")
+                let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
+                for &v in vs.reinterpret_cast() {
+                    f.write_u16::<LittleEndian>(v)?
+                }
            }
             DType::F16 => {
-                todo!("f16")
+                let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
+                for &v in vs.reinterpret_cast() {
+                    f.write_u16::<LittleEndian>(v)?
+                }
             }
             DType::F32 => {
                 // TODO: Avoid using a buffer when data is already on the CPU.
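
Not part of the patch: a minimal round-trip sketch for the new f16 read/write paths. It assumes the crate is imported as `candle` and that the module's existing `Tensor::read_npy`/`Tensor::write_npy` helpers are used; `test.npy` is an illustrative file saved from NumPy with `dtype=np.float16`, so adjust names and paths as needed.

    use candle::{DType, Result, Tensor};

    fn main() -> Result<()> {
        // Illustrative input: a float16 array saved via `np.save("test.npy", x.astype(np.float16))`.
        let t = Tensor::read_npy("test.npy")?;
        assert_eq!(t.dtype(), DType::F16);
        // Exercise the new f16 write path, then read the file back and check the dtype survives.
        t.write_npy("test-roundtrip.npy")?;
        let t2 = Tensor::read_npy("test-roundtrip.npy")?;
        assert_eq!(t2.dtype(), DType::F16);
        Ok(())
    }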