mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Adding new surface for savetensors (global load, global save).
This commit is contained in:
@ -79,7 +79,7 @@ Cheatsheet:
|
|||||||
| Arithmetic | `a + b` | `&a + &b` |
|
| Arithmetic | `a + b` | `&a + &b` |
|
||||||
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
|
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
|
||||||
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
|
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
|
||||||
| Saving | `torch.save({"A": A}, "model.bin")` | `safetensors::serialize_to_file(HashMap::from([("A", A)]), "model.safetensors")?` |
|
| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(HashMap::from([("A", A)]), "model.safetensors")?` |
|
||||||
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
|
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
|
||||||
|
|
||||||
<!--- ANCHOR_END: cheatsheet --->
|
<!--- ANCHOR_END: cheatsheet --->
|
||||||
|
@ -249,6 +249,10 @@ pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<Stri
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn save<P: AsRef<Path>>(tensors: HashMap<&str, Tensor>, filename: P) -> Result<()> {
|
||||||
|
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
|
||||||
|
}
|
||||||
|
|
||||||
pub struct MmapedFile(memmap2::Mmap);
|
pub struct MmapedFile(memmap2::Mmap);
|
||||||
|
|
||||||
impl MmapedFile {
|
impl MmapedFile {
|
||||||
@ -289,7 +293,7 @@ mod tests {
|
|||||||
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
|
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
|
||||||
let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
|
let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
|
||||||
let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
|
let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
|
||||||
st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
|
save(map, "multi.safetensors").unwrap();
|
||||||
|
|
||||||
let weights = load("multi.safetensors", &Device::Cpu).unwrap();
|
let weights = load("multi.safetensors", &Device::Cpu).unwrap();
|
||||||
assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
|
assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
|
||||||
|
Reference in New Issue
Block a user