mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add VarBuilder::from_backend
(#1670)
`candle-nn` already exposes a trait to define custom backends. However, it's not possible to actually construct a `VarBuilder` with a custom backend because the constructor is not exposed. This change makes the constructor public and renames it from `new` to `from_backend` to avoid that it is seen as the primary constructor (which could be confusing to users).
This commit is contained in:
@ -412,7 +412,16 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
|
||||
}
|
||||
|
||||
impl<'a> VarBuilder<'a> {
|
||||
fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
|
||||
/// Initializes a `VarBuilder` using a custom backend.
|
||||
///
|
||||
/// It is preferred to use one of the more specific constructors. This
|
||||
/// constructor is provided to allow downstream users to define their own
|
||||
/// backends.
|
||||
pub fn from_backend(
|
||||
backend: Box<dyn SimpleBackend + 'a>,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
) -> Self {
|
||||
let data = TensorData {
|
||||
backend,
|
||||
dtype,
|
||||
@ -427,13 +436,13 @@ impl<'a> VarBuilder<'a> {
|
||||
|
||||
/// Initializes a `VarBuilder` that uses zeros for any tensor.
|
||||
pub fn zeros(dtype: DType, dev: &Device) -> Self {
|
||||
Self::new(Box::new(Zeros), dtype, dev.clone())
|
||||
Self::from_backend(Box::new(Zeros), dtype, dev.clone())
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is
|
||||
/// returned if no tensor is available under the requested path or on shape mismatches.
|
||||
pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
|
||||
Self::new(Box::new(ts), dtype, dev.clone())
|
||||
Self::from_backend(Box::new(ts), dtype, dev.clone())
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and
|
||||
@ -443,7 +452,7 @@ impl<'a> VarBuilder<'a> {
|
||||
/// Note that it is possible to load the tensor values after model creation using the `load`
|
||||
/// method on `varmap`, this can be used to start model training from an existing checkpoint.
|
||||
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
|
||||
Self::new(Box::new(varmap.clone()), dtype, dev.clone())
|
||||
Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone())
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
|
||||
@ -458,25 +467,25 @@ impl<'a> VarBuilder<'a> {
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
|
||||
Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
|
||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
|
||||
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
||||
Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
|
||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
||||
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let npz = candle::npy::NpzTensors::new(p)?;
|
||||
Ok(Self::new(Box::new(npz), dtype, dev.clone()))
|
||||
Ok(Self::from_backend(Box::new(npz), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
|
||||
pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let pth = candle::pickle::PthTensors::new(p)?;
|
||||
Ok(Self::new(Box::new(pth), dtype, dev.clone()))
|
||||
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user