mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Modifying safetensors
export to get simple load and save.
This commit is contained in:
@ -72,14 +72,15 @@ Cheatsheet:
|
||||
| | Using PyTorch | Using Candle |
|
||||
|------------|------------------------------------------|------------------------------------------------------------------|
|
||||
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.]], [3., 4.]], &Device::Cpu)?` |
|
||||
| Creation | `torch.zeros((2, 2))` | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?` |
|
||||
| Indexing | `tensor[:, :4]` | `tensor.i((.., ..4))?` |
|
||||
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
|
||||
| Operations | `a.matmul(b)` | `a.matmul(&b)?` |
|
||||
| Arithmetic | `a + b` | `&a + &b` |
|
||||
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
|
||||
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
|
||||
| Saving | `torch.save({"A": A}, "model.bin")` | `tensor.save_safetensors("A", "model.safetensors")?` |
|
||||
| Loading | `weights = torch.load("model.bin")` | TODO (see the examples for now) |
|
||||
| Saving | `torch.save({"A": A}, "model.bin")` | `safetensors::serialize_to_file(HashMap::from([("A", A)]), "model.safetensors")?` |
|
||||
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
|
||||
|
||||
<!--- ANCHOR_END: cheatsheet --->
|
||||
|
||||
|
@ -2,6 +2,8 @@ use crate::{DType, Device, Error, Result, Tensor, WithDType};
|
||||
use safetensors::tensor as st;
|
||||
use safetensors::tensor::SafeTensors;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
impl From<DType> for st::Dtype {
|
||||
fn from(value: DType) -> Self {
|
||||
@ -238,6 +240,15 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
|
||||
let data = std::fs::read(filename.as_ref())?;
|
||||
let st = safetensors::SafeTensors::deserialize(&data)?;
|
||||
st.tensors()
|
||||
.into_iter()
|
||||
.map(|(name, view)| Ok((name, view.load(device)?)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub struct MmapedFile(memmap2::Mmap);
|
||||
|
||||
impl MmapedFile {
|
||||
@ -274,11 +285,15 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_multiple_tensors() {
|
||||
fn save_load_multiple_tensors() {
|
||||
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
|
||||
let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
|
||||
let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
|
||||
st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
|
||||
|
||||
let weights = load("multi.safetensors", &Device::Cpu).unwrap();
|
||||
assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
|
||||
assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
|
||||
let bytes = std::fs::read("multi.safetensors").unwrap();
|
||||
assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
|
||||
std::fs::remove_file("multi.safetensors").unwrap();
|
||||
|
Reference in New Issue
Block a user