Removing inner dependency on safetensors.

This commit is contained in:
Nicolas Patry
2023-07-26 11:16:04 +02:00
parent 1553b58fe5
commit 7c7e6ba201
4 changed files with 30 additions and 32 deletions

View File

@ -1,6 +1,5 @@
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
use safetensors::slice::IndexOp;
use safetensors::tensor::SafeTensors;
use safetensors::{slice::IndexOp, tensor::SafeTensors};
use std::collections::HashMap;
use std::sync::Arc;
@ -70,7 +69,7 @@ impl<'a> TensorData<'a> {
#[derive(Clone)]
pub struct VarBuilder<'a> {
data: Arc<TensorData<'a>>,
pub path: Vec<String>,
path: Vec<String>,
}
impl<'a> VarBuilder<'a> {
@ -179,7 +178,10 @@ impl<'a> VarBuilder<'a> {
shape[dim] = block_size;
Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)?
let dtype: DType = dtype.try_into()?;
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
}
_ => unimplemented!(),
};