Move some safetensors bits in the candle-core crate.

This commit is contained in:
laurent
2023-07-03 08:37:46 +01:00
parent 9e419641fb
commit cf2789fb81
3 changed files with 30 additions and 31 deletions

View File

@ -1,38 +1,10 @@
use super::*;
use candle::{Device, Result, Tensor};
use half::f16;
use memmap2::MmapOptions;
use safetensors::{
tensor::{Dtype, TensorView},
SafeTensors,
};
use safetensors::SafeTensors;
use std::fs::File;
use std::path::PathBuf;
fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
Dtype::F16 => {
let v = view.data();
if (v.as_ptr() as usize) % 2 == 0 {
// SAFETY This is safe because we just checked that this
// was correctly aligned.
let data: &[f16] =
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
Tensor::from_slice(data, view.shape(), device)?.to_dtype(DTYPE)
} else {
let mut c = Vec::with_capacity(v.len() / 2);
let mut i = 0;
while i < v.len() {
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
i += 2;
}
Tensor::from_slice(&c, view.shape(), device)?.to_dtype(DTYPE)
}
}
dt => todo!("Unhandled dtype {dt:?}"),
}
}
pub struct VarBuilder<'a> {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,
@ -59,8 +31,7 @@ impl<'a> VarBuilder<'a> {
// Unwrap or 0 just to let the proper error flow.
let index = self.routing.get(tensor_name).unwrap_or(&0);
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
let tensor = convert(view, &self.device)?;
Ok(tensor)
candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
}
}

View File

@ -10,6 +10,7 @@ mod error;
mod layout;
mod npy;
mod op;
pub mod safetensors;
mod shape;
mod storage;
mod strided_index;

View File

@ -0,0 +1,27 @@
use crate::{Device, Result, Tensor};
use half::f16;
use safetensors::tensor as st;
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::F16 => {
let v = view.data();
if (v.as_ptr() as usize) % 2 == 0 {
// SAFETY This is safe because we just checked that this
// was correctly aligned.
let data: &[f16] =
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
Tensor::from_slice(data, view.shape(), device)
} else {
let mut c = Vec::with_capacity(v.len() / 2);
let mut i = 0;
while i < v.len() {
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
i += 2;
}
Tensor::from_slice(&c, view.shape(), device)
}
}
dt => todo!("Unhandled dtype {dt:?}"),
}
}