diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 03929681..bf5d5b43 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,3 +1,6 @@ +//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come +//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_safetensors`, or initialized for +//! training, e.g. using `VarBuilder::from_varmap`. use crate::VarMap; use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; use safetensors::{slice::IndexOp, tensor::SafeTensors}; @@ -107,6 +110,15 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { self.path.join(".") } + /// Returns a new `VarBuilder` using the root path. + pub fn root(&self) -> Self { + Self { + data: self.data.clone(), + path: vec![], + _phantom: std::marker::PhantomData, + } + } + /// Returns a new `VarBuilder` with the prefix set to `prefix`. pub fn set_prefix(&self, prefix: impl ToString) -> Self { Self { @@ -327,18 +339,29 @@ impl<'a> VarBuilder<'a> { } } + /// Initializes a `VarBuilder` that uses zeros for any tensor. pub fn zeros(dtype: DType, dev: &Device) -> Self { Self::new(Box::new(Zeros), dtype, dev.clone()) } + /// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is + /// returned if no tensor is available under the requested path or on shape mismatches. pub fn from_tensors(ts: HashMap, dtype: DType, dev: &Device) -> Self { Self::new(Box::new(ts), dtype, dev.clone()) } + /// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and + /// initialized on new paths, the same tensor is used if the same path is requested multiple + /// times. This is commonly used when initializing a model before training. + /// + /// Note that it is possible to load the tensor values after model creation using the `load` + /// method on `varmap`, this can be used to start model training from an existing checkpoint. pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self { Self::new(Box::new(varmap.clone()), dtype, dev.clone()) } + /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors + /// files. pub fn from_safetensors(safetensors: Vec>, dtype: DType, dev: &Device) -> Self { let mut routing = HashMap::new(); for (index, sf) in safetensors.iter().enumerate() { @@ -353,6 +376,7 @@ impl<'a> VarBuilder<'a> { Self::new(Box::new(tensors), dtype, dev.clone()) } + /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file. pub fn from_npz>(p: P, dtype: DType, dev: &Device) -> Result { let npz = candle::npy::NpzTensors::new(p)?; Ok(Self::new(Box::new(npz), dtype, dev.clone()))