mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Provide a method to allow PTH files with state maps to be loaded. (#2639)
* Provide a method to allow PTH files iwth state maps to be loaded. * add a line to the doc * String-. &str
This commit is contained in:
@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> {
|
||||
let pth = candle::pickle::PthTensors::new(p, None)?;
|
||||
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
|
||||
/// similar to [`from_pth`] but requires a `state_key`.
|
||||
pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
dtype: DType,
|
||||
state_key: &str,
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
|
||||
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
|
||||
}
|
||||
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
|
||||
/// passing the new names to the inner VarBuilder.
|
||||
///
|
||||
|
Reference in New Issue
Block a user