Read all the tensors in a PyTorch pth file. (#1106)

This commit is contained in:
Laurent Mazare
2023-10-16 13:50:07 +01:00
committed by GitHub
parent 588ad4835a
commit 0106b0b04c

View File

@ -723,3 +723,16 @@ impl PthTensors {
Ok(Some(tensor))
}
}
/// Read all the tensors from a PyTorch pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
if let Some(tensor) = pth.get(name)? {
tensors.push((name.to_string(), tensor))
}
}
Ok(tensors)
}