mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
VarMap setter functions (#938)
* Add some setter helper functions for varmap. * Add more comments.
This commit is contained in:
@ -57,6 +57,44 @@ impl VarMap {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set a named variable to some value.
|
||||||
|
pub fn set_one<K: AsRef<String>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
|
||||||
|
let tensor_data = self.data.lock().unwrap();
|
||||||
|
let name = name.as_ref();
|
||||||
|
match tensor_data.get(name) {
|
||||||
|
None => candle::bail!("cannot find {name} in VarMap"),
|
||||||
|
Some(var) => {
|
||||||
|
if let Err(err) = var.set(value.as_ref()) {
|
||||||
|
candle::bail!("error setting {name}: {err}",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set some named variables to some values.
|
||||||
|
///
|
||||||
|
/// If an error is returned, some of the variables might have already been set to their new
|
||||||
|
/// values.
|
||||||
|
pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
|
||||||
|
&mut self,
|
||||||
|
iter: I,
|
||||||
|
) -> Result<()> {
|
||||||
|
let tensor_data = self.data.lock().unwrap();
|
||||||
|
for (name, value) in iter {
|
||||||
|
let name = name.as_ref();
|
||||||
|
match tensor_data.get(name) {
|
||||||
|
None => candle::bail!("cannot find {name} in VarMap"),
|
||||||
|
Some(var) => {
|
||||||
|
if let Err(err) = var.set(value.as_ref()) {
|
||||||
|
candle::bail!("error setting {name}: {err}",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Retrieve or add a new variable.
|
/// Retrieve or add a new variable.
|
||||||
pub fn get<S: Into<Shape>>(
|
pub fn get<S: Into<Shape>>(
|
||||||
&self,
|
&self,
|
||||||
|
Reference in New Issue
Block a user