From afb5e24a63a603653bcb0425ebdbffa1ee047ded Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 1 Aug 2023 17:19:22 +0200 Subject: [PATCH] Remove map ownership from `save`. --- README.md | 2 +- candle-core/src/safetensors.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c6f25702..5ca83d25 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ Cheatsheet: | Arithmetic | `a + b` | `&a + &b` | | Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` | | Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` | -| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(HashMap::from([("A", A)]), "model.safetensors")?` | +| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` | | Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` | diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 195bbcf7..1880a041 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -249,7 +249,7 @@ pub fn load>(filename: P, device: &Device) -> Result>(tensors: HashMap<&str, Tensor>, filename: P) -> Result<()> { +pub fn save>(tensors: &HashMap<&str, Tensor>, filename: P) -> Result<()> { Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) } @@ -293,7 +293,7 @@ mod tests { let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap(); let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect(); - save(map, "multi.safetensors").unwrap(); + save(&map, "multi.safetensors").unwrap(); let weights = load("multi.safetensors", &Device::Cpu).unwrap(); assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);