mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add some doc to the varbuilder. (#700)
This commit is contained in:
@ -1,3 +1,6 @@
|
|||||||
|
//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
|
||||||
|
//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_safetensors`, or initialized for
|
||||||
|
//! training, e.g. using `VarBuilder::from_varmap`.
|
||||||
use crate::VarMap;
|
use crate::VarMap;
|
||||||
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
|
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
|
||||||
use safetensors::{slice::IndexOp, tensor::SafeTensors};
|
use safetensors::{slice::IndexOp, tensor::SafeTensors};
|
||||||
@ -107,6 +110,15 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
self.path.join(".")
|
self.path.join(".")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a new `VarBuilder` using the root path.
|
||||||
|
pub fn root(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
data: self.data.clone(),
|
||||||
|
path: vec![],
|
||||||
|
_phantom: std::marker::PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a new `VarBuilder` with the prefix set to `prefix`.
|
/// Returns a new `VarBuilder` with the prefix set to `prefix`.
|
||||||
pub fn set_prefix(&self, prefix: impl ToString) -> Self {
|
pub fn set_prefix(&self, prefix: impl ToString) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -327,18 +339,29 @@ impl<'a> VarBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes a `VarBuilder` that uses zeros for any tensor.
|
||||||
pub fn zeros(dtype: DType, dev: &Device) -> Self {
|
pub fn zeros(dtype: DType, dev: &Device) -> Self {
|
||||||
Self::new(Box::new(Zeros), dtype, dev.clone())
|
Self::new(Box::new(Zeros), dtype, dev.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is
|
||||||
|
/// returned if no tensor is available under the requested path or on shape mismatches.
|
||||||
pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
|
pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
|
||||||
Self::new(Box::new(ts), dtype, dev.clone())
|
Self::new(Box::new(ts), dtype, dev.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and
|
||||||
|
/// initialized on new paths, the same tensor is used if the same path is requested multiple
|
||||||
|
/// times. This is commonly used when initializing a model before training.
|
||||||
|
///
|
||||||
|
/// Note that it is possible to load the tensor values after model creation using the `load`
|
||||||
|
/// method on `varmap`, this can be used to start model training from an existing checkpoint.
|
||||||
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
|
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
|
||||||
Self::new(Box::new(varmap.clone()), dtype, dev.clone())
|
Self::new(Box::new(varmap.clone()), dtype, dev.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
|
||||||
|
/// files.
|
||||||
pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self {
|
pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self {
|
||||||
let mut routing = HashMap::new();
|
let mut routing = HashMap::new();
|
||||||
for (index, sf) in safetensors.iter().enumerate() {
|
for (index, sf) in safetensors.iter().enumerate() {
|
||||||
@ -353,6 +376,7 @@ impl<'a> VarBuilder<'a> {
|
|||||||
Self::new(Box::new(tensors), dtype, dev.clone())
|
Self::new(Box::new(tensors), dtype, dev.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
||||||
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
let npz = candle::npy::NpzTensors::new(p)?;
|
let npz = candle::npy::NpzTensors::new(p)?;
|
||||||
Ok(Self::new(Box::new(npz), dtype, dev.clone()))
|
Ok(Self::new(Box::new(npz), dtype, dev.clone()))
|
||||||
|
Reference in New Issue
Block a user