mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Handle f16/bf16 in npy.
This commit is contained in:
22
src/npy.rs
22
src/npy.rs
@@ -27,6 +27,7 @@
|
||||
//! ```
|
||||
use crate::{DType, Device, Error, Result, Shape, Tensor};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Read, Write};
|
||||
@@ -80,7 +81,7 @@ impl Header {
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
let descr = match self.descr {
|
||||
DType::BF16 => todo!("bf16"),
|
||||
DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
|
||||
DType::F16 => "f2",
|
||||
DType::F32 => "f4",
|
||||
DType::F64 => "f8",
|
||||
@@ -193,14 +194,19 @@ impl Header {
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
// TODO: Add the possibility to read directly to a device?
|
||||
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::BF16 => {
|
||||
todo!("bf16")
|
||||
let mut data_t = vec![bf16::ZERO; elem_count];
|
||||
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
DType::F16 => {
|
||||
todo!("f16")
|
||||
let mut data_t = vec![f16::ZERO; elem_count];
|
||||
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data_t = vec![0f32; elem_count];
|
||||
@@ -298,10 +304,16 @@ impl Tensor {
|
||||
let elem_count = self.elem_count();
|
||||
match self.dtype() {
|
||||
DType::BF16 => {
|
||||
todo!("bf16")
|
||||
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
todo!("f16")
|
||||
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F32 => {
|
||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||
|
Reference in New Issue
Block a user