fix: can directly save the loaded weights (#421)

This commit is contained in:
Yumin Wu
2023-08-12 23:33:29 +08:00
committed by GitHub
parent 0741ebbd51
commit 16b89f5b83

View File

@ -253,7 +253,10 @@ pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tenso
.collect() .collect()
} }
pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Result<()> { pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
tensors: &HashMap<K, Tensor>,
filename: P,
) -> Result<()> {
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
} }