mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
VarMap setter functions (#938)
* Add some setter helper functions for varmap. * Add more comments.
This commit is contained in:
@ -57,6 +57,44 @@ impl VarMap {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a named variable to some value.
|
||||
pub fn set_one<K: AsRef<String>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
|
||||
let tensor_data = self.data.lock().unwrap();
|
||||
let name = name.as_ref();
|
||||
match tensor_data.get(name) {
|
||||
None => candle::bail!("cannot find {name} in VarMap"),
|
||||
Some(var) => {
|
||||
if let Err(err) = var.set(value.as_ref()) {
|
||||
candle::bail!("error setting {name}: {err}",)
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set some named variables to some values.
|
||||
///
|
||||
/// If an error is returned, some of the variables might have already been set to their new
|
||||
/// values.
|
||||
pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
|
||||
&mut self,
|
||||
iter: I,
|
||||
) -> Result<()> {
|
||||
let tensor_data = self.data.lock().unwrap();
|
||||
for (name, value) in iter {
|
||||
let name = name.as_ref();
|
||||
match tensor_data.get(name) {
|
||||
None => candle::bail!("cannot find {name} in VarMap"),
|
||||
Some(var) => {
|
||||
if let Err(err) = var.set(value.as_ref()) {
|
||||
candle::bail!("error setting {name}: {err}",)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Retrieve or add a new variable.
|
||||
pub fn get<S: Into<Shape>>(
|
||||
&self,
|
||||
|
Reference in New Issue
Block a user