From b75e8945bc7c67106be6288b9f357efa8068e62e Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Tue, 6 Feb 2024 14:17:33 -0600 Subject: [PATCH] Enhance pickle to retrieve state_dict with a given key (#1671) --- candle-core/examples/tensor-tools.rs | 2 +- candle-core/src/pickle.rs | 54 ++++++++++++++++++++++++--- candle-core/tests/pth.py | 2 + candle-core/tests/pth_tests.rs | 10 ++++- candle-core/tests/test_with_key.pt | Bin 0 -> 1338 bytes candle-nn/src/var_builder.rs | 2 +- 6 files changed, 61 insertions(+), 9 deletions(-) create mode 100644 candle-core/tests/test_with_key.pt diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index eb6ceb1c..1801ac58 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -196,7 +196,7 @@ fn run_ls( } } Format::Pth => { - let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?; + let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?; tensors.sort_by(|a, b| a.name.cmp(&b.name)); for tensor_info in tensors.iter() { println!( diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 4c76c416..2c189131 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -625,9 +625,16 @@ pub struct TensorInfo { pub storage_size: usize, } +/// Read the tensor info from a .pth file. +/// +/// # Arguments +/// * `file` - The path to the .pth file. +/// * `verbose` - Whether to print debug information. +/// * `key` - Optional key to retrieve `state_dict` from the pth file. pub fn read_pth_tensor_info>( file: P, verbose: bool, + key: Option<&str>, ) -> Result> { let file = std::fs::File::open(file)?; let zip_reader = std::io::BufReader::new(file); @@ -649,8 +656,9 @@ pub fn read_pth_tensor_info>( stack.read_loop(&mut reader)?; let obj = stack.finalize()?; if VERBOSE || verbose { - println!("{obj:?}"); + println!("{obj:#?}"); } + let obj = match obj { Object::Build { callable, args } => match *callable { Object::Reduce { callable, args: _ } => match *callable { @@ -664,6 +672,24 @@ pub fn read_pth_tensor_info>( }, obj => obj, }; + + // If key is provided, then we need to extract the state_dict from the object. + let obj = if let Some(key) = key { + if let Object::Dict(key_values) = obj { + key_values + .into_iter() + .find(|(k, _)| *k == Object::Unicode(key.to_owned())) + .map(|(_, v)| v) + .ok_or_else(|| E::Msg(format!("key {key} not found")))? + } else { + obj + } + } else { + obj + }; + + // If the object is a dict, then we can extract the tensor info from it. + // NOTE: We are assuming that the `obj` is state_dict by this stage. if let Object::Dict(key_values) = obj { for (name, value) in key_values.into_iter() { match value.into_tensor_info(name, &dir_name) { @@ -686,8 +712,8 @@ pub struct PthTensors { } impl PthTensors { - pub fn new>(path: P) -> Result { - let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?; + pub fn new>(path: P, key: Option<&str>) -> Result { + let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?; let tensor_infos = tensor_infos .into_iter() .map(|ti| (ti.name.to_string(), ti)) @@ -735,9 +761,17 @@ impl PthTensors { } } -/// Read all the tensors from a PyTorch pth file. -pub fn read_all>(path: P) -> Result> { - let pth = PthTensors::new(path)?; +/// Read all the tensors from a PyTorch pth file with a given key. +/// +/// # Arguments +/// * `path` - Path to the pth file. +/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file +/// contains multiple objects and the state_dict is the one we are interested in. +pub fn read_all_with_key>( + path: P, + key: Option<&str>, +) -> Result> { + let pth = PthTensors::new(path, key)?; let tensor_names = pth.tensor_infos.keys(); let mut tensors = Vec::with_capacity(tensor_names.len()); for name in tensor_names { @@ -747,3 +781,11 @@ pub fn read_all>(path: P) -> Result>(path: P) -> Result> { + read_all_with_key(path, None) +} diff --git a/candle-core/tests/pth.py b/candle-core/tests/pth.py index 97724712..cab94f2c 100644 --- a/candle-core/tests/pth.py +++ b/candle-core/tests/pth.py @@ -6,3 +6,5 @@ a= torch.tensor([[1,2,3,4], [5,6,7,8]]) o = OrderedDict() o["test"] = a torch.save(o, "test.pt") + +torch.save({"model_state_dict": o}, "test_with_key.pt") diff --git a/candle-core/tests/pth_tests.rs b/candle-core/tests/pth_tests.rs index b09d1026..ad788ed9 100644 --- a/candle-core/tests/pth_tests.rs +++ b/candle-core/tests/pth_tests.rs @@ -1,6 +1,14 @@ /// Regression test for pth files not loading on Windows. #[test] fn test_pth() { - let tensors = candle_core::pickle::PthTensors::new("tests/test.pt").unwrap(); + let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap(); + tensors.get("test").unwrap().unwrap(); +} + +#[test] +fn test_pth_with_key() { + let tensors = + candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict")) + .unwrap(); tensors.get("test").unwrap().unwrap(); } diff --git a/candle-core/tests/test_with_key.pt b/candle-core/tests/test_with_key.pt new file mode 100644 index 0000000000000000000000000000000000000000..a598e02c42ea4daf5450f7888dada9b2874471e0 GIT binary patch literal 1338 zcmbVLOK;Oa5Z=6vDU?H-iVLS6f>csFZQ}Y+30Wu-S|m&nlEnez+N;E>iM#fu4OV8_A{%=Ua&@QUk(LBL$$hm%--75a)d;S;$UxT1w^n387>AHq2KYo>>4lqPFX(eg-Z1eK!Yr}cbF5km4 zMqWW11%Nh;xri}WZ5bp#%(Vi@y9JO`dS!(CFxCqQCOo-FBa=WGBHyXa0>t+Ph|6ba t(q?lj%e+;)Gsl?ec^02%3viYRAMz53|Hu{qk4@UZrG#B*5WJt>{S7%wAvgd4 literal 0 HcmV?d00001 diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 33d94c83..bf090219 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -484,7 +484,7 @@ impl<'a> VarBuilder<'a> { /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. pub fn from_pth>(p: P, dtype: DType, dev: &Device) -> Result { - let pth = candle::pickle::PthTensors::new(p)?; + let pth = candle::pickle::PthTensors::new(p, None)?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } }