Modifying safetensors export to get simple load and save.

This commit is contained in:
Nicolas Patry
2023-07-27 15:20:13 +02:00
parent 836ba3e090
commit 310094310b
2 changed files with 19 additions and 3 deletions

View File

@ -2,6 +2,8 @@ use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::Path;
impl From<DType> for st::Dtype {
fn from(value: DType) -> Self {
@ -238,6 +240,15 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
}
}
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
let data = std::fs::read(filename.as_ref())?;
let st = safetensors::SafeTensors::deserialize(&data)?;
st.tensors()
.into_iter()
.map(|(name, view)| Ok((name, view.load(device)?)))
.collect()
}
pub struct MmapedFile(memmap2::Mmap);
impl MmapedFile {
@ -274,11 +285,15 @@ mod tests {
}
#[test]
fn save_multiple_tensors() {
fn save_load_multiple_tensors() {
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
let weights = load("multi.safetensors", &Device::Cpu).unwrap();
assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
let bytes = std::fs::read("multi.safetensors").unwrap();
assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
std::fs::remove_file("multi.safetensors").unwrap();