diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs index f61fad23..c17558b7 100644 --- a/candle-nn/src/var_map.rs +++ b/candle-nn/src/var_map.rs @@ -57,6 +57,44 @@ impl VarMap { Ok(()) } + /// Set a named variable to some value. + pub fn set_one, V: AsRef>(&mut self, name: K, value: V) -> Result<()> { + let tensor_data = self.data.lock().unwrap(); + let name = name.as_ref(); + match tensor_data.get(name) { + None => candle::bail!("cannot find {name} in VarMap"), + Some(var) => { + if let Err(err) = var.set(value.as_ref()) { + candle::bail!("error setting {name}: {err}",) + } + } + } + Ok(()) + } + + /// Set some named variables to some values. + /// + /// If an error is returned, some of the variables might have already been set to their new + /// values. + pub fn set, K: AsRef, V: AsRef>( + &mut self, + iter: I, + ) -> Result<()> { + let tensor_data = self.data.lock().unwrap(); + for (name, value) in iter { + let name = name.as_ref(); + match tensor_data.get(name) { + None => candle::bail!("cannot find {name} in VarMap"), + Some(var) => { + if let Err(err) = var.set(value.as_ref()) { + candle::bail!("error setting {name}: {err}",) + } + } + } + } + Ok(()) + } + /// Retrieve or add a new variable. pub fn get>( &self,