Enhance pickle to retrieve state_dict with a given key (#1671)

This commit is contained in:
Dilshod Tadjibaev
2024-02-06 14:17:33 -06:00
committed by GitHub
parent a90fc5ca5a
commit b75e8945bc
6 changed files with 61 additions and 9 deletions

View File

@ -484,7 +484,7 @@ impl<'a> VarBuilder<'a> {
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let pth = candle::pickle::PthTensors::new(p)?;
let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
}