VarMap setter functions (#938)

* Add some setter helper functions for varmap.

* Add more comments.
This commit is contained in:
Laurent Mazare
2023-09-23 10:27:51 +01:00
committed by GitHub
parent 7582937a32
commit 402d207f0f

View File

@ -57,6 +57,44 @@ impl VarMap {
Ok(())
}
/// Set a named variable to some value.
pub fn set_one<K: AsRef<String>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
let tensor_data = self.data.lock().unwrap();
let name = name.as_ref();
match tensor_data.get(name) {
None => candle::bail!("cannot find {name} in VarMap"),
Some(var) => {
if let Err(err) = var.set(value.as_ref()) {
candle::bail!("error setting {name}: {err}",)
}
}
}
Ok(())
}
/// Set some named variables to some values.
///
/// If an error is returned, some of the variables might have already been set to their new
/// values.
pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
&mut self,
iter: I,
) -> Result<()> {
let tensor_data = self.data.lock().unwrap();
for (name, value) in iter {
let name = name.as_ref();
match tensor_data.get(name) {
None => candle::bail!("cannot find {name} in VarMap"),
Some(var) => {
if let Err(err) = var.set(value.as_ref()) {
candle::bail!("error setting {name}: {err}",)
}
}
}
}
Ok(())
}
/// Retrieve or add a new variable.
pub fn get<S: Into<Shape>>(
&self,