From 0106b0b04c3505a1155b3eab65ac212977c6c3dd Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 16 Oct 2023 13:50:07 +0100 Subject: [PATCH] Read all the tensors in a PyTorch pth file. (#1106) --- candle-core/src/pickle.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 37c15018..0013113a 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -723,3 +723,16 @@ impl PthTensors { Ok(Some(tensor)) } } + +/// Read all the tensors from a PyTorch pth file. +pub fn read_all>(path: P) -> Result> { + let pth = PthTensors::new(path)?; + let tensor_names = pth.tensor_infos.keys(); + let mut tensors = Vec::with_capacity(tensor_names.len()); + for name in tensor_names { + if let Some(tensor) = pth.get(name)? { + tensors.push((name.to_string(), tensor)) + } + } + Ok(tensors) +}