Enhance pickle to retrieve state_dict with a given key (#1671)

This commit is contained in:
Dilshod Tadjibaev
2024-02-06 14:17:33 -06:00
committed by GitHub
parent a90fc5ca5a
commit b75e8945bc
6 changed files with 61 additions and 9 deletions

View File

@ -1,6 +1,14 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt").unwrap();
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_with_key() {
let tensors =
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
.unwrap();
tensors.get("test").unwrap().unwrap();
}