mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Removing inner dependency on safetensors.
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
use crate::{DType, Device, Error, Result, Tensor, WithDType};
|
||||
use safetensors::slice::SliceIterator;
|
||||
use safetensors::tensor as st;
|
||||
use safetensors::tensor::{Dtype, SafeTensors};
|
||||
use safetensors::tensor::SafeTensors;
|
||||
use std::borrow::Cow;
|
||||
|
||||
impl From<DType> for st::Dtype {
|
||||
@ -118,26 +117,24 @@ impl<'a> Load for st::TensorView<'a> {
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
pub fn from_safetensors_slice(
|
||||
iterator: SliceIterator,
|
||||
dtype: Dtype,
|
||||
pub fn from_raw_buffer(
|
||||
data: &[u8],
|
||||
dtype: DType,
|
||||
shape: &[usize],
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let data: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||
match dtype {
|
||||
st::Dtype::U8 => convert_slice::<u8>(&data, shape, device),
|
||||
st::Dtype::U32 => convert_slice::<u8>(&data, shape, device),
|
||||
st::Dtype::BF16 => convert_slice::<half::bf16>(&data, shape, device),
|
||||
st::Dtype::F16 => convert_slice::<half::f16>(&data, shape, device),
|
||||
st::Dtype::F32 => convert_slice::<f32>(&data, shape, device),
|
||||
st::Dtype::F64 => convert_slice::<f64>(&data, shape, device),
|
||||
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
|
||||
DType::U8 => convert_slice::<u8>(data, shape, device),
|
||||
DType::U32 => convert_slice::<u32>(data, shape, device),
|
||||
DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
|
||||
DType::F16 => convert_slice::<half::f16>(data, shape, device),
|
||||
DType::F32 => convert_slice::<f32>(data, shape, device),
|
||||
DType::F64 => convert_slice::<f64>(data, shape, device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
match view.dtype() {
|
||||
st::Dtype::U8 => convert_::<u8>(view, device),
|
||||
st::Dtype::U32 => convert_::<u8>(view, device),
|
||||
@ -149,7 +146,7 @@ pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
// TODO: This makes an unnecessary copy when the tensor is on the cpu.
|
||||
let tensor = tensor.flatten_all()?;
|
||||
match tensor.dtype() {
|
||||
|
Reference in New Issue
Block a user