Adding new surface for savetensors (global load, global save).

This commit is contained in:
Nicolas Patry
2023-07-27 15:32:42 +02:00
parent 310094310b
commit 89d1fd03e5
2 changed files with 6 additions and 2 deletions

View File

@ -249,6 +249,10 @@ pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<Stri
.collect()
}
pub fn save<P: AsRef<Path>>(tensors: HashMap<&str, Tensor>, filename: P) -> Result<()> {
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
pub struct MmapedFile(memmap2::Mmap);
impl MmapedFile {
@ -289,7 +293,7 @@ mod tests {
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
save(map, "multi.safetensors").unwrap();
let weights = load("multi.safetensors", &Device::Cpu).unwrap();
assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);