Provide a method to allow PTH files with state maps to be loaded. (#2639)

* Provide a method to allow PTH files iwth state maps to be loaded.

* add a line to the doc

* String-. &str
This commit is contained in:
zachcp
2024-11-26 16:52:53 -05:00
committed by GitHub
parent c12db594e3
commit b4deb5c5a9

View File

@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> {
let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
/// similar to [`from_pth`] but requires a `state_key`.
pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
p: P,
dtype: DType,
state_key: &str,
dev: &Device,
) -> Result<Self> {
let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
/// passing the new names to the inner VarBuilder.
///