Handle more types in safetensors.

This commit is contained in:
laurent
2023-07-03 10:09:46 +01:00
parent 783b7054ee
commit 899c76de75
2 changed files with 30 additions and 21 deletions

View File

@ -114,6 +114,9 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
SafeTensor(#[from] safetensors::SafeTensorError), SafeTensor(#[from] safetensors::SafeTensorError),
#[error("unsupported safetensor dtype {0:?}")]
UnsupportedSafeTensorDtype(safetensors::Dtype),
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")] #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape }, BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
} }

View File

@ -1,28 +1,34 @@
use crate::{Device, Result, Tensor}; use crate::{Device, Error, Result, Tensor, WithDType};
use half::f16;
use safetensors::tensor as st; use safetensors::tensor as st;
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> { fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::F16 => {
let v = view.data(); let v = view.data();
if (v.as_ptr() as usize) % 2 == 0 { let size_in_bytes = T::DTYPE.size_in_bytes();
let elem_count = v.len() / size_in_bytes;
if (v.as_ptr() as usize) % size_in_bytes == 0 {
// SAFETY This is safe because we just checked that this // SAFETY This is safe because we just checked that this
// was correctly aligned. // was correctly aligned.
let data: &[f16] = let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
Tensor::from_slice(data, view.shape(), device) Tensor::from_slice(data, view.shape(), device)
} else { } else {
let mut c = Vec::with_capacity(v.len() / 2); let mut c = Vec::with_capacity(elem_count);
let mut i = 0; unsafe {
while i < v.len() { std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
c.push(f16::from_le_bytes([v[i], v[i + 1]])); c.set_len(elem_count)
i += 2;
} }
Tensor::from_slice(&c, view.shape(), device) Tensor::from_slice(&c, view.shape(), device)
} }
} }
dt => todo!("Unhandled dtype {dt:?}"),
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::U8 => convert_::<u8>(view, device),
st::Dtype::U32 => convert_::<u8>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
} }
} }