From d2622a8160a429835ab0a3aa2145ab8c9de4cdd9 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 20 Aug 2023 14:25:07 +0100 Subject: [PATCH] Move the VarMap to a separate file (#525) * Move the var-map struct in a separate file. * Fix some typos. --- candle-nn/src/lib.rs | 4 +- candle-nn/src/var_builder.rs | 97 +++--------------------------------- candle-nn/src/var_map.rs | 87 ++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 91 deletions(-) create mode 100644 candle-nn/src/var_map.rs diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 7b486f71..e195ac67 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -12,6 +12,7 @@ pub mod loss; pub mod ops; pub mod optim; pub mod var_builder; +pub mod var_map; pub use activation::Activation; pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; @@ -22,7 +23,8 @@ pub use init::Init; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; pub use linear::{linear, linear_no_bias, Linear}; pub use optim::{AdamW, ParamsAdamW, SGD}; -pub use var_builder::{VarBuilder, VarMap}; +pub use var_builder::VarBuilder; +pub use var_map::VarMap; // A simple trait defining a module with forward method using a single argument. pub trait Module: std::fmt::Debug { diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 311a6fb9..b24ed56d 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,91 +1,8 @@ -use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor, Var}; +use crate::VarMap; +use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; use safetensors::{slice::IndexOp, tensor::SafeTensors}; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; - -/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores -/// and new variables can be added by providing some initialization config in case they are -/// missing. -/// `VarMap` structures can be serialized in the safetensors format. -#[derive(Clone)] -pub struct VarMap { - data: Arc>>, -} - -impl VarMap { - /// Create a new empty `VarMap`. - #[allow(clippy::new_without_default)] - pub fn new() -> Self { - let data = Arc::new(Mutex::new(HashMap::new())); - Self { data } - } - - /// Retrieve all the variables currently stored in the map. - pub fn all_vars(&self) -> Vec { - let tensor_data = self.data.lock().unwrap(); - #[allow(clippy::map_clone)] - tensor_data.values().map(|c| c.clone()).collect::>() - } - - /// Save the map in the safetensors format. - pub fn save>(&self, path: P) -> Result<()> { - let tensor_data = self.data.lock().unwrap(); - let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor())); - safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?; - Ok(()) - } - - /// Load some values from a safetensors file and modify the existing variables to have these - /// values. - /// - /// Note that values for variables that are currently not in the map are not kept. - pub fn load>(&mut self, path: P) -> Result<()> { - let path = path.as_ref(); - let data = unsafe { candle::safetensors::MmapedFile::new(path)? }; - let data = data.deserialize()?; - let mut tensor_data = self.data.lock().unwrap(); - for (name, var) in tensor_data.iter_mut() { - match data.tensor(name) { - Ok(data) => { - let data: Tensor = data.load(var.device())?; - if let Err(err) = var.set(&data) { - candle::bail!("error setting {name} using data from {path:?}: {err}",) - } - } - Err(_) => candle::bail!("cannot find tensor for {name}"), - } - } - Ok(()) - } - - /// Retrieve or add a new variable. - pub fn get>( - &self, - shape: S, - path: &str, - init: crate::Init, - dtype: DType, - device: &Device, - ) -> Result { - let shape = shape.into(); - let mut tensor_data = self.data.lock().unwrap(); - if let Some(tensor) = tensor_data.get(path) { - let tensor_shape = tensor.shape(); - if &shape != tensor_shape { - candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}") - } - return Ok(tensor.as_tensor().clone()); - } - let var = init.var(shape, dtype, device)?; - let tensor = var.as_tensor().clone(); - tensor_data.insert(path.to_string(), var); - Ok(tensor) - } - - pub fn data(&self) -> &Mutex> { - &self.data - } -} +use std::sync::Arc; // TODO: Maybe we would want the storage to be generic, e.g. with Box to avoid too many // generics. @@ -307,7 +224,7 @@ impl<'a> VarBuilder<'a> { Ok(tensor) } - /// Retrieve the tensor associted with the current name and path. + /// Retrieve the tensor associated with the given name at the current path. pub fn get>(&self, s: S, tensor_name: &str) -> Result { let data = self.data.as_ref(); let s: Shape = s.into(); @@ -324,7 +241,7 @@ impl<'a> VarBuilder<'a> { })? .clone(), Tensors::VarMap(varmap) => { - let data = varmap.data.lock().unwrap(); + let data = varmap.data().lock().unwrap(); data.get(&path) .ok_or_else(|| { Error::CannotFindTensor { @@ -368,8 +285,8 @@ impl<'a> VarBuilder<'a> { Ok(tensor) } - /// Retrieve the tensor associted with the current name and path or initialize a new tensor if - /// it's missing. + /// Retrieve the tensor associated with the given name at the current path or initialize a new + /// tensor if it's missing. /// /// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`. pub fn get_or_init>( diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs new file mode 100644 index 00000000..f61fad23 --- /dev/null +++ b/candle-nn/src/var_map.rs @@ -0,0 +1,87 @@ +use candle::{safetensors::Load, DType, Device, Result, Shape, Tensor, Var}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores +/// and new variables can be added by providing some initialization config in case they are +/// missing. +/// `VarMap` structures can be serialized in the safetensors format. +#[derive(Clone)] +pub struct VarMap { + data: Arc>>, +} + +impl VarMap { + /// Create a new empty `VarMap`. + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + let data = Arc::new(Mutex::new(HashMap::new())); + Self { data } + } + + /// Retrieve all the variables currently stored in the map. + pub fn all_vars(&self) -> Vec { + let tensor_data = self.data.lock().unwrap(); + #[allow(clippy::map_clone)] + tensor_data.values().map(|c| c.clone()).collect::>() + } + + /// Save the map in the safetensors format. + pub fn save>(&self, path: P) -> Result<()> { + let tensor_data = self.data.lock().unwrap(); + let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor())); + safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?; + Ok(()) + } + + /// Load some values from a safetensors file and modify the existing variables to have these + /// values. + /// + /// Note that values for variables that are currently not in the map are not kept. + pub fn load>(&mut self, path: P) -> Result<()> { + let path = path.as_ref(); + let data = unsafe { candle::safetensors::MmapedFile::new(path)? }; + let data = data.deserialize()?; + let mut tensor_data = self.data.lock().unwrap(); + for (name, var) in tensor_data.iter_mut() { + match data.tensor(name) { + Ok(data) => { + let data: Tensor = data.load(var.device())?; + if let Err(err) = var.set(&data) { + candle::bail!("error setting {name} using data from {path:?}: {err}",) + } + } + Err(_) => candle::bail!("cannot find tensor for {name}"), + } + } + Ok(()) + } + + /// Retrieve or add a new variable. + pub fn get>( + &self, + shape: S, + path: &str, + init: crate::Init, + dtype: DType, + device: &Device, + ) -> Result { + let shape = shape.into(); + let mut tensor_data = self.data.lock().unwrap(); + if let Some(tensor) = tensor_data.get(path) { + let tensor_shape = tensor.shape(); + if &shape != tensor_shape { + candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}") + } + return Ok(tensor.as_tensor().clone()); + } + let var = init.var(shape, dtype, device)?; + let tensor = var.as_tensor().clone(); + tensor_data.insert(path.to_string(), var); + Ok(tensor) + } + + pub fn data(&self) -> &Mutex> { + &self.data + } +}