mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Handle more types in safetensors.
This commit is contained in:
@ -114,6 +114,9 @@ pub enum Error {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
SafeTensor(#[from] safetensors::SafeTensorError),
|
SafeTensor(#[from] safetensors::SafeTensorError),
|
||||||
|
|
||||||
|
#[error("unsupported safetensor dtype {0:?}")]
|
||||||
|
UnsupportedSafeTensorDtype(safetensors::Dtype),
|
||||||
|
|
||||||
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
||||||
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
|
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
|
||||||
}
|
}
|
||||||
|
@ -1,28 +1,34 @@
|
|||||||
use crate::{Device, Result, Tensor};
|
use crate::{Device, Error, Result, Tensor, WithDType};
|
||||||
use half::f16;
|
|
||||||
use safetensors::tensor as st;
|
use safetensors::tensor as st;
|
||||||
|
|
||||||
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||||
match view.dtype() {
|
|
||||||
st::Dtype::F16 => {
|
|
||||||
let v = view.data();
|
let v = view.data();
|
||||||
if (v.as_ptr() as usize) % 2 == 0 {
|
let size_in_bytes = T::DTYPE.size_in_bytes();
|
||||||
|
let elem_count = v.len() / size_in_bytes;
|
||||||
|
if (v.as_ptr() as usize) % size_in_bytes == 0 {
|
||||||
// SAFETY This is safe because we just checked that this
|
// SAFETY This is safe because we just checked that this
|
||||||
// was correctly aligned.
|
// was correctly aligned.
|
||||||
let data: &[f16] =
|
let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
|
||||||
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
|
|
||||||
Tensor::from_slice(data, view.shape(), device)
|
Tensor::from_slice(data, view.shape(), device)
|
||||||
} else {
|
} else {
|
||||||
let mut c = Vec::with_capacity(v.len() / 2);
|
let mut c = Vec::with_capacity(elem_count);
|
||||||
let mut i = 0;
|
unsafe {
|
||||||
while i < v.len() {
|
std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
|
||||||
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
|
c.set_len(elem_count)
|
||||||
i += 2;
|
|
||||||
}
|
}
|
||||||
Tensor::from_slice(&c, view.shape(), device)
|
Tensor::from_slice(&c, view.shape(), device)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dt => todo!("Unhandled dtype {dt:?}"),
|
|
||||||
|
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||||
|
match view.dtype() {
|
||||||
|
st::Dtype::U8 => convert_::<u8>(view, device),
|
||||||
|
st::Dtype::U32 => convert_::<u8>(view, device),
|
||||||
|
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
|
||||||
|
st::Dtype::F16 => convert_::<half::f16>(view, device),
|
||||||
|
st::Dtype::F32 => convert_::<f32>(view, device),
|
||||||
|
st::Dtype::F64 => convert_::<f64>(view, device),
|
||||||
|
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user