From a90fc5ca5a486e988d39ea69ee3d3bb40a39c017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 6 Feb 2024 15:26:11 +0100 Subject: [PATCH] Add `VarBuilder::from_backend` (#1670) `candle-nn` already exposes a trait to define custom backends. However, it's not possible to actually construct a `VarBuilder` with a custom backend because the constructor is not exposed. This change makes the constructor public and renames it from `new` to `from_backend` to avoid that it is seen as the primary constructor (which could be confusing to users). --- candle-nn/src/var_builder.rs | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 83c86a6f..33d94c83 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -412,7 +412,16 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors { } impl<'a> VarBuilder<'a> { - fn new(backend: Box, dtype: DType, device: Device) -> Self { + /// Initializes a `VarBuilder` using a custom backend. + /// + /// It is preferred to use one of the more specific constructors. This + /// constructor is provided to allow downstream users to define their own + /// backends. + pub fn from_backend( + backend: Box, + dtype: DType, + device: Device, + ) -> Self { let data = TensorData { backend, dtype, @@ -427,13 +436,13 @@ impl<'a> VarBuilder<'a> { /// Initializes a `VarBuilder` that uses zeros for any tensor. pub fn zeros(dtype: DType, dev: &Device) -> Self { - Self::new(Box::new(Zeros), dtype, dev.clone()) + Self::from_backend(Box::new(Zeros), dtype, dev.clone()) } /// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is /// returned if no tensor is available under the requested path or on shape mismatches. pub fn from_tensors(ts: HashMap, dtype: DType, dev: &Device) -> Self { - Self::new(Box::new(ts), dtype, dev.clone()) + Self::from_backend(Box::new(ts), dtype, dev.clone()) } /// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and @@ -443,7 +452,7 @@ impl<'a> VarBuilder<'a> { /// Note that it is possible to load the tensor values after model creation using the `load` /// method on `varmap`, this can be used to start model training from an existing checkpoint. pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self { - Self::new(Box::new(varmap.clone()), dtype, dev.clone()) + Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone()) } /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors @@ -458,25 +467,25 @@ impl<'a> VarBuilder<'a> { dev: &Device, ) -> Result { let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?; - Ok(Self::new(Box::new(tensors), dtype, dev.clone())) + Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) } /// Initializes a `VarBuilder` from a binary builder in the safetensor format. pub fn from_buffered_safetensors(data: Vec, dtype: DType, dev: &Device) -> Result { let tensors = candle::safetensors::BufferedSafetensors::new(data)?; - Ok(Self::new(Box::new(tensors), dtype, dev.clone())) + Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) } /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file. pub fn from_npz>(p: P, dtype: DType, dev: &Device) -> Result { let npz = candle::npy::NpzTensors::new(p)?; - Ok(Self::new(Box::new(npz), dtype, dev.clone())) + Ok(Self::from_backend(Box::new(npz), dtype, dev.clone())) } /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. pub fn from_pth>(p: P, dtype: DType, dev: &Device) -> Result { let pth = candle::pickle::PthTensors::new(p)?; - Ok(Self::new(Box::new(pth), dtype, dev.clone())) + Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } }