Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)

* When converting a tensor to a variable, clone if the tensor is already a variable.

* Add a test to ensure training a batch norm works with VarMaps

---------

Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
This commit is contained in:
Jeffrey Dallatezza
2024-04-29 02:21:53 -07:00
committed by GitHub
parent 3bbb88fcb4
commit a0d03aded1
2 changed files with 51 additions and 4 deletions

View File

@ -34,9 +34,14 @@ impl Var {
Ok(Self(inner))
}
// Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
pub fn from_tensor(t: &Tensor) -> Result<Self> {
let inner = t.make_var()?;
Ok(Self(inner))
if t.is_variable() {
Ok(Self(t.clone()))
} else {
let inner = t.make_var()?;
Ok(Self(inner))
}
}
pub fn rand_f64<S: Into<Shape>>(