From 310094310b6cd6b3a0b2b8f02659e73db3496855 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 27 Jul 2023 15:20:13 +0200
Subject: [PATCH] Modifying `safetensors` export to get simple load and save.

---
 README.md                      |  5 +++--
 candle-core/src/safetensors.rs | 17 ++++++++++++++++-
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index b6a30c17..6ea18b10 100644
--- a/README.md
+++ b/README.md
@@ -72,14 +72,15 @@ Cheatsheet:
 |            | Using PyTorch                        | Using Candle                                               |
 |------------|--------------------------------------|------------------------------------------------------------|
 | Creation   | `torch.Tensor([[1, 2], [3, 4]])`     | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?`      |
+| Creation   | `torch.zeros((2, 2))`                | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?`         |
 | Indexing   | `tensor[:, :4]`                      | `tensor.i((.., ..4))?`                                     |
 | Operations | `tensor.view((2, 2))`                | `tensor.reshape((2, 2))?`                                  |
 | Operations | `a.matmul(b)`                        | `a.matmul(&b)?`                                            |
 | Arithmetic | `a + b`                              | `&a + &b`                                                  |
 | Device     | `tensor.to(device="cuda")`           | `tensor.to_device(&Device::Cuda(0))?`                      |
 | Dtype      | `tensor.to(dtype=torch.float16)`     | `tensor.to_dtype(&DType::F16)?`                            |
-| Saving     | `torch.save({"A": A}, "model.bin")`  | `tensor.save_safetensors("A", "model.safetensors")?`       |
-| Loading    | `weights = torch.load("model.bin")`  | TODO (see the examples for now)                            |
+| Saving     | `torch.save({"A": A}, "model.bin")`  | `safetensors::serialize_to_file(HashMap::from([("A", A)]), "model.safetensors")?` |
+| Loading    | `weights = torch.load("model.bin")`  | `candle::safetensors::load("model.safetensors", &device)`  |

diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 06b9b23b..8ebc68a6 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -2,6 +2,8 @@ use crate::{DType, Device, Error, Result, Tensor, WithDType};
 use safetensors::tensor as st;
 use safetensors::tensor::SafeTensors;
 use std::borrow::Cow;
+use std::collections::HashMap;
+use std::path::Path;
 
 impl From<DType> for st::Dtype {
     fn from(value: DType) -> Self {
@@ -238,6 +240,15 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
     }
 }
 
+pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
+    let data = std::fs::read(filename.as_ref())?;
+    let st = safetensors::SafeTensors::deserialize(&data)?;
+    st.tensors()
+        .into_iter()
+        .map(|(name, view)| Ok((name, view.load(device)?)))
+        .collect()
+}
+
 pub struct MmapedFile(memmap2::Mmap);
 
 impl MmapedFile {
@@ -274,11 +285,15 @@ mod tests {
     }
 
     #[test]
-    fn save_multiple_tensors() {
+    fn save_load_multiple_tensors() {
         let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
         let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
         let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
         st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
+
+        let weights = load("multi.safetensors", &Device::Cpu).unwrap();
+        assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
+        assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
         let bytes = std::fs::read("multi.safetensors").unwrap();
         assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
         std::fs::remove_file("multi.safetensors").unwrap();
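
Editor's note: below is a minimal sketch of the save/load roundtrip this patch enables, mirroring the updated test. It assumes the library is imported as `candle` (as in the README cheatsheet above) with the `safetensors` crate as a direct dependency; the tensor name "A", the shape, and the file name "model.safetensors" are illustrative only.

use std::collections::HashMap;

use candle::safetensors::load;
use candle::{DType, Device, Tensor};
use safetensors::tensor as st;

fn main() {
    let device = Device::Cpu;
    let a = Tensor::zeros((2, 3), DType::F32, &device).unwrap();

    // Save: `Tensor` implements the safetensors `View` trait, so a
    // name -> tensor map serializes straight to disk (no metadata here,
    // hence the `&None`).
    let map: HashMap<_, _> = [("A", a)].into_iter().collect();
    st::serialize_to_file(map, &None, std::path::Path::new("model.safetensors")).unwrap();

    // Load: reads every tensor in the file back onto the given device,
    // as a HashMap<String, Tensor>.
    let weights = load("model.safetensors", &device).unwrap();
    assert_eq!(weights.get("A").unwrap().dims(), &[2, 3]);

    std::fs::remove_file("model.safetensors").unwrap();
}

Returning a HashMap<String, Tensor> mirrors the dict-of-tensors shape of `torch.load` in the cheatsheet. Note that `load` reads the whole file into memory via `std::fs::read`; the existing `MmapedFile` path remains available when memory-mapped access is preferred.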