Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Move the VarMap to a separate file (#525)
* Move the var-map struct in a separate file.
* Fix some typos.
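Since `lib.rs` re-exports `VarMap` from the new `var_map` module (first hunk below), downstream code that refers to `candle_nn::VarMap` keeps compiling unchanged. For context, a minimal usage sketch of the API being moved, based only on the methods visible in this diff (the shape, variable name, and init config are illustrative):

use candle::{DType, Device, Result};
use candle_nn::{Init, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Create an empty store and add (or retrieve) a named variable.
    let varmap = VarMap::new();
    let w = varmap.get((2, 3), "w", Init::Const(0.5), DType::F32, &device)?;
    assert_eq!(w.dims(), &[2, 3]);
    // Collect every variable currently in the store, e.g. to hand to an optimizer.
    let vars = varmap.all_vars();
    assert_eq!(vars.len(), 1);
    Ok(())
}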
candle-nn/src/lib.rs

@@ -12,6 +12,7 @@ pub mod loss;
 pub mod ops;
 pub mod optim;
 pub mod var_builder;
+pub mod var_map;
 
 pub use activation::Activation;
 pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
@@ -22,7 +23,8 @@ pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::{AdamW, ParamsAdamW, SGD};
-pub use var_builder::{VarBuilder, VarMap};
+pub use var_builder::VarBuilder;
+pub use var_map::VarMap;
 
 // A simple trait defining a module with forward method using a single argument.
 pub trait Module: std::fmt::Debug {
candle-nn/src/var_builder.rs

@@ -1,91 +1,8 @@
-use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor, Var};
+use crate::VarMap;
+use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
 use safetensors::{slice::IndexOp, tensor::SafeTensors};
 use std::collections::HashMap;
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
 
-/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
-/// and new variables can be added by providing some initialization config in case they are
-/// missing.
-/// `VarMap` structures can be serialized in the safetensors format.
-#[derive(Clone)]
-pub struct VarMap {
-    data: Arc<Mutex<HashMap<String, Var>>>,
-}
-
-impl VarMap {
-    /// Create a new empty `VarMap`.
-    #[allow(clippy::new_without_default)]
-    pub fn new() -> Self {
-        let data = Arc::new(Mutex::new(HashMap::new()));
-        Self { data }
-    }
-
-    /// Retrieve all the variables currently stored in the map.
-    pub fn all_vars(&self) -> Vec<Var> {
-        let tensor_data = self.data.lock().unwrap();
-        #[allow(clippy::map_clone)]
-        tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
-    }
-
-    /// Save the map in the safetensors format.
-    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
-        let tensor_data = self.data.lock().unwrap();
-        let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
-        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
-        Ok(())
-    }
-
-    /// Load some values from a safetensors file and modify the existing variables to have these
-    /// values.
-    ///
-    /// Note that values for variables that are currently not in the map are not kept.
-    pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
-        let path = path.as_ref();
-        let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
-        let data = data.deserialize()?;
-        let mut tensor_data = self.data.lock().unwrap();
-        for (name, var) in tensor_data.iter_mut() {
-            match data.tensor(name) {
-                Ok(data) => {
-                    let data: Tensor = data.load(var.device())?;
-                    if let Err(err) = var.set(&data) {
-                        candle::bail!("error setting {name} using data from {path:?}: {err}",)
-                    }
-                }
-                Err(_) => candle::bail!("cannot find tensor for {name}"),
-            }
-        }
-        Ok(())
-    }
-
-    /// Retrieve or add a new variable.
-    pub fn get<S: Into<Shape>>(
-        &self,
-        shape: S,
-        path: &str,
-        init: crate::Init,
-        dtype: DType,
-        device: &Device,
-    ) -> Result<Tensor> {
-        let shape = shape.into();
-        let mut tensor_data = self.data.lock().unwrap();
-        if let Some(tensor) = tensor_data.get(path) {
-            let tensor_shape = tensor.shape();
-            if &shape != tensor_shape {
-                candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
-            }
-            return Ok(tensor.as_tensor().clone());
-        }
-        let var = init.var(shape, dtype, device)?;
-        let tensor = var.as_tensor().clone();
-        tensor_data.insert(path.to_string(), var);
-        Ok(tensor)
-    }
-
-    pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
-        &self.data
-    }
-}
-
 // TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
 // generics.
@@ -307,7 +224,7 @@ impl<'a> VarBuilder<'a> {
         Ok(tensor)
     }
 
-    /// Retrieve the tensor associted with the current name and path.
+    /// Retrieve the tensor associated with the given name at the current path.
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
@@ -324,7 +241,7 @@ impl<'a> VarBuilder<'a> {
                 })?
                 .clone(),
             Tensors::VarMap(varmap) => {
-                let data = varmap.data.lock().unwrap();
+                let data = varmap.data().lock().unwrap();
                 data.get(&path)
                     .ok_or_else(|| {
                         Error::CannotFindTensor {
@@ -368,8 +285,8 @@ impl<'a> VarBuilder<'a> {
         Ok(tensor)
     }
 
-    /// Retrieve the tensor associted with the current name and path or initialize a new tensor if
-    /// it's missing.
+    /// Retrieve the tensor associated with the given name at the current path or initialize a new
+    /// tensor if it's missing.
    ///
    /// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`.
    pub fn get_or_init<S: Into<Shape>>(
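The `varmap.data` to `varmap.data().lock().unwrap()` change above is a direct consequence of the move: with `VarMap` now living in its own module, the private `data` field is no longer visible from `var_builder.rs`, so access goes through the public `data()` accessor instead. A minimal sketch of the same pattern from external code (the printed format is illustrative):

// Inspect the store through the public accessor; the field itself stays private.
let varmap = candle_nn::VarMap::new();
for (name, var) in varmap.data().lock().unwrap().iter() {
    println!("{name}: {:?}", var.shape());
}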
candle-nn/src/var_map.rs (new file, 87 lines)

@@ -0,0 +1,87 @@
+use candle::{safetensors::Load, DType, Device, Result, Shape, Tensor, Var};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
+/// and new variables can be added by providing some initialization config in case they are
+/// missing.
+/// `VarMap` structures can be serialized in the safetensors format.
+#[derive(Clone)]
+pub struct VarMap {
+    data: Arc<Mutex<HashMap<String, Var>>>,
+}
+
+impl VarMap {
+    /// Create a new empty `VarMap`.
+    #[allow(clippy::new_without_default)]
+    pub fn new() -> Self {
+        let data = Arc::new(Mutex::new(HashMap::new()));
+        Self { data }
+    }
+
+    /// Retrieve all the variables currently stored in the map.
+    pub fn all_vars(&self) -> Vec<Var> {
+        let tensor_data = self.data.lock().unwrap();
+        #[allow(clippy::map_clone)]
+        tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
+    }
+
+    /// Save the map in the safetensors format.
+    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
+        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
+        Ok(())
+    }
+
+    /// Load some values from a safetensors file and modify the existing variables to have these
+    /// values.
+    ///
+    /// Note that values for variables that are currently not in the map are not kept.
+    pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
+        let path = path.as_ref();
+        let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
+        let data = data.deserialize()?;
+        let mut tensor_data = self.data.lock().unwrap();
+        for (name, var) in tensor_data.iter_mut() {
+            match data.tensor(name) {
+                Ok(data) => {
+                    let data: Tensor = data.load(var.device())?;
+                    if let Err(err) = var.set(&data) {
+                        candle::bail!("error setting {name} using data from {path:?}: {err}",)
+                    }
+                }
+                Err(_) => candle::bail!("cannot find tensor for {name}"),
+            }
+        }
+        Ok(())
+    }
+
+    /// Retrieve or add a new variable.
+    pub fn get<S: Into<Shape>>(
+        &self,
+        shape: S,
+        path: &str,
+        init: crate::Init,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Tensor> {
+        let shape = shape.into();
+        let mut tensor_data = self.data.lock().unwrap();
+        if let Some(tensor) = tensor_data.get(path) {
+            let tensor_shape = tensor.shape();
+            if &shape != tensor_shape {
+                candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
+            }
+            return Ok(tensor.as_tensor().clone());
+        }
+        let var = init.var(shape, dtype, device)?;
+        let tensor = var.as_tensor().clone();
+        tensor_data.insert(path.to_string(), var);
+        Ok(tensor)
+    }
+
+    pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
+        &self.data
+    }
+}
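To close, a hedged sketch of a save/load roundtrip using only the methods defined in the new file; the checkpoint path and init parameters are illustrative assumptions, and note that `load` requires a mutable binding:

use candle::{DType, Device, Result};
use candle_nn::{Init, VarMap};

fn checkpoint_roundtrip() -> Result<()> {
    let device = Device::Cpu;
    let mut varmap = VarMap::new();
    // `get` registers a fresh variable under the given name on first use.
    let _w = varmap.get(
        (4, 4),
        "w",
        Init::Randn { mean: 0., stdev: 1. },
        DType::F32,
        &device,
    )?;
    // Persist every variable in the safetensors format...
    varmap.save("checkpoint.safetensors")?;
    // ...and later overwrite the in-memory values with the stored ones.
    varmap.load("checkpoint.safetensors")?;
    Ok(())
}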