Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Handle f16/bf16 in npy.

 src/npy.rs | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)
@@ -27,6 +27,7 @@
 //! ```
 use crate::{DType, Device, Error, Result, Shape, Tensor};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use half::{bf16, f16, slice::HalfFloatSliceExt};
 use std::collections::HashMap;
 use std::fs::File;
 use std::io::{BufReader, Read, Write};
@@ -80,7 +81,7 @@ impl Header {
             .collect::<Vec<_>>()
             .join(",");
         let descr = match self.descr {
-            DType::BF16 => todo!("bf16"),
+            DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
             DType::F16 => "f2",
             DType::F32 => "f4",
             DType::F64 => "f8",
@@ -193,14 +194,19 @@ impl Header {
 }
 
 impl Tensor {
+    // TODO: Add the possibility to read directly to a device?
     fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
         let elem_count = shape.elem_count();
         match dtype {
             DType::BF16 => {
-                todo!("bf16")
+                let mut data_t = vec![bf16::ZERO; elem_count];
+                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
             DType::F16 => {
-                todo!("f16")
+                let mut data_t = vec![f16::ZERO; elem_count];
+                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
             DType::F32 => {
                 let mut data_t = vec![0f32; elem_count];
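The read path fills a zero-initialized half-precision buffer by viewing it as &mut [u16] through the half crate's HalfFloatSliceExt::reinterpret_cast_mut, so byteorder can decode all little-endian words in a single call rather than converting element by element. A minimal self-contained sketch of the same technique, using only the two crates the diff imports:

    use byteorder::{LittleEndian, ReadBytesExt};
    use half::{f16, slice::HalfFloatSliceExt};
    use std::io::Cursor;

    fn main() -> std::io::Result<()> {
        // Raw little-endian f16 payload: 1.0 is 0x3C00, -2.0 is 0xC000.
        let raw: &[u8] = &[0x00, 0x3C, 0x00, 0xC0];
        // Zero-initialize the destination, as from_reader does with f16::ZERO.
        let mut data = vec![f16::ZERO; 2];
        // reinterpret_cast_mut views the &mut [f16] as &mut [u16] without copying,
        // so read_u16_into decodes every element in one pass.
        Cursor::new(raw).read_u16_into::<LittleEndian>(data.reinterpret_cast_mut())?;
        assert_eq!(data, [f16::from_f32(1.0), f16::from_f32(-2.0)]);
        Ok(())
    }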
@@ -298,10 +304,16 @@ impl Tensor {
         let elem_count = self.elem_count();
         match self.dtype() {
             DType::BF16 => {
-                todo!("bf16")
+                let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
+                for &v in vs.reinterpret_cast() {
+                    f.write_u16::<LittleEndian>(v)?
+                }
             }
             DType::F16 => {
-                todo!("f16")
+                let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
+                for &v in vs.reinterpret_cast() {
+                    f.write_u16::<LittleEndian>(v)?
+                }
             }
             DType::F32 => {
                 // TODO: Avoid using a buffer when data is already on the CPU.
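The write path is the mirror image: to_vec1 pulls the elements onto the host, reinterpret_cast views the half-precision slice as &[u16], and each word goes out little-endian. Note that Header::descr above still rejects bf16, since the .npy format has no standard bf16 type descriptor; only these raw element read/write paths handle it. A sketch of the write direction, round-tripping through the same reader call as before:

    use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
    use half::{bf16, slice::HalfFloatSliceExt};
    use std::io::Cursor;

    fn main() -> std::io::Result<()> {
        let values = vec![bf16::from_f32(0.5), bf16::from_f32(3.25)];
        // Write: view &[bf16] as &[u16] and emit each word little-endian.
        let mut buf = Vec::new();
        for &bits in values.reinterpret_cast() {
            buf.write_u16::<LittleEndian>(bits)?;
        }
        // Read it back the same way from_reader does.
        let mut decoded = vec![bf16::ZERO; values.len()];
        Cursor::new(&buf).read_u16_into::<LittleEndian>(decoded.reinterpret_cast_mut())?;
        assert_eq!(decoded, values);
        Ok(())
    }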
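End to end, f16 tensors should now round-trip through .npy files readable by NumPy as float16. A hypothetical usage sketch, assuming the read_npy/write_npy entry points defined elsewhere in src/npy.rs (not shown in this diff) and candle as the crate name:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let t = Tensor::zeros((2, 3), DType::F16, &Device::Cpu)?;
        // Serializes a header with the "f2" descr, then the raw LE words.
        t.write_npy("t.npy")?;
        // numpy.load("t.npy") would see float16; read_npy restores the f16 tensor.
        let t2 = Tensor::read_npy("t.npy")?;
        assert_eq!(t2.dtype(), DType::F16);
        Ok(())
    }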