From 899c76de7567572f522f8711e10150de8e4e0d6f Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 3 Jul 2023 10:09:46 +0100 Subject: [PATCH] Handle more types in safetensors. --- candle-core/src/error.rs | 3 +++ candle-core/src/safetensors.rs | 48 +++++++++++++++++++--------------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 71fd21de..d5de4296 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -114,6 +114,9 @@ pub enum Error { #[error(transparent)] SafeTensor(#[from] safetensors::SafeTensorError), + #[error("unsupported safetensor dtype {0:?}")] + UnsupportedSafeTensorDtype(safetensors::Dtype), + #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")] BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape }, } diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index b586b544..b80a756a 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -1,28 +1,34 @@ -use crate::{Device, Result, Tensor}; -use half::f16; +use crate::{Device, Error, Result, Tensor, WithDType}; use safetensors::tensor as st; +fn convert_(view: st::TensorView<'_>, device: &Device) -> Result { + let v = view.data(); + let size_in_bytes = T::DTYPE.size_in_bytes(); + let elem_count = v.len() / size_in_bytes; + if (v.as_ptr() as usize) % size_in_bytes == 0 { + // SAFETY This is safe because we just checked that this + // was correctly aligned. + let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) }; + Tensor::from_slice(data, view.shape(), device) + } else { + let mut c = Vec::with_capacity(elem_count); + unsafe { + std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len()); + c.set_len(elem_count) + } + Tensor::from_slice(&c, view.shape(), device) + } +} + pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { match view.dtype() { - st::Dtype::F16 => { - let v = view.data(); - if (v.as_ptr() as usize) % 2 == 0 { - // SAFETY This is safe because we just checked that this - // was correctly aligned. - let data: &[f16] = - unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) }; - Tensor::from_slice(data, view.shape(), device) - } else { - let mut c = Vec::with_capacity(v.len() / 2); - let mut i = 0; - while i < v.len() { - c.push(f16::from_le_bytes([v[i], v[i + 1]])); - i += 2; - } - Tensor::from_slice(&c, view.shape(), device) - } - } - dt => todo!("Unhandled dtype {dt:?}"), + st::Dtype::U8 => convert_::(view, device), + st::Dtype::U32 => convert_::(view, device), + st::Dtype::BF16 => convert_::(view, device), + st::Dtype::F16 => convert_::(view, device), + st::Dtype::F32 => convert_::(view, device), + st::Dtype::F64 => convert_::(view, device), + dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } }