mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Provide a method to allow PTH files with state maps to be loaded. (#2639)
* Provide a method to allow PTH files iwth state maps to be loaded. * add a line to the doc * String-. &str
This commit is contained in:
@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> {
|
|||||||
let pth = candle::pickle::PthTensors::new(p, None)?;
|
let pth = candle::pickle::PthTensors::new(p, None)?;
|
||||||
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
|
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
|
||||||
}
|
}
|
||||||
|
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
|
||||||
|
/// similar to [`from_pth`] but requires a `state_key`.
|
||||||
|
pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
|
||||||
|
p: P,
|
||||||
|
dtype: DType,
|
||||||
|
state_key: &str,
|
||||||
|
dev: &Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
|
||||||
|
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
|
||||||
|
}
|
||||||
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
|
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
|
||||||
/// passing the new names to the inner VarBuilder.
|
/// passing the new names to the inner VarBuilder.
|
||||||
///
|
///
|
||||||
|
Reference in New Issue
Block a user