mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Removing inner dependency on safetensors.
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
|
||||
use safetensors::slice::IndexOp;
|
||||
use safetensors::tensor::SafeTensors;
|
||||
use safetensors::{slice::IndexOp, tensor::SafeTensors};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -70,7 +69,7 @@ impl<'a> TensorData<'a> {
|
||||
#[derive(Clone)]
|
||||
pub struct VarBuilder<'a> {
|
||||
data: Arc<TensorData<'a>>,
|
||||
pub path: Vec<String>,
|
||||
path: Vec<String>,
|
||||
}
|
||||
|
||||
impl<'a> VarBuilder<'a> {
|
||||
@ -179,7 +178,10 @@ impl<'a> VarBuilder<'a> {
|
||||
|
||||
shape[dim] = block_size;
|
||||
|
||||
Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)?
|
||||
let dtype: DType = dtype.try_into()?;
|
||||
|
||||
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||
Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
|
Reference in New Issue
Block a user