mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Move some safetensors bits into the candle-core crate.
This commit is contained in:
@ -1,38 +1,10 @@
|
||||
use super::*;
|
||||
use candle::{Device, Result, Tensor};
|
||||
use half::f16;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{
|
||||
tensor::{Dtype, TensorView},
|
||||
SafeTensors,
|
||||
};
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs::File;
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
match view.dtype() {
|
||||
Dtype::F16 => {
|
||||
let v = view.data();
|
||||
if (v.as_ptr() as usize) % 2 == 0 {
|
||||
// SAFETY This is safe because we just checked that this
|
||||
// was correctly aligned.
|
||||
let data: &[f16] =
|
||||
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
|
||||
Tensor::from_slice(data, view.shape(), device)?.to_dtype(DTYPE)
|
||||
} else {
|
||||
let mut c = Vec::with_capacity(v.len() / 2);
|
||||
let mut i = 0;
|
||||
while i < v.len() {
|
||||
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
|
||||
i += 2;
|
||||
}
|
||||
Tensor::from_slice(&c, view.shape(), device)?.to_dtype(DTYPE)
|
||||
}
|
||||
}
|
||||
dt => todo!("Unhandled dtype {dt:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VarBuilder<'a> {
|
||||
routing: HashMap<String, usize>,
|
||||
safetensors: Vec<SafeTensors<'a>>,
|
||||
@ -59,8 +31,7 @@ impl<'a> VarBuilder<'a> {
|
||||
// Unwrap or 0 just to let the proper error flow.
|
||||
let index = self.routing.get(tensor_name).unwrap_or(&0);
|
||||
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
|
||||
let tensor = convert(view, &self.device)?;
|
||||
Ok(tensor)
|
||||
candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,7 @@ mod error;
|
||||
mod layout;
|
||||
mod npy;
|
||||
mod op;
|
||||
pub mod safetensors;
|
||||
mod shape;
|
||||
mod storage;
|
||||
mod strided_index;
|
||||
|
27
candle-core/src/safetensors.rs
Normal file
27
candle-core/src/safetensors.rs
Normal file
@ -0,0 +1,27 @@
|
||||
use crate::{Device, Result, Tensor};
|
||||
use half::f16;
|
||||
use safetensors::tensor as st;
|
||||
|
||||
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
match view.dtype() {
|
||||
st::Dtype::F16 => {
|
||||
let v = view.data();
|
||||
if (v.as_ptr() as usize) % 2 == 0 {
|
||||
// SAFETY This is safe because we just checked that this
|
||||
// was correctly aligned.
|
||||
let data: &[f16] =
|
||||
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
|
||||
Tensor::from_slice(data, view.shape(), device)
|
||||
} else {
|
||||
let mut c = Vec::with_capacity(v.len() / 2);
|
||||
let mut i = 0;
|
||||
while i < v.len() {
|
||||
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
|
||||
i += 2;
|
||||
}
|
||||
Tensor::from_slice(&c, view.shape(), device)
|
||||
}
|
||||
}
|
||||
dt => todo!("Unhandled dtype {dt:?}"),
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user