mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable. * Add a test to ensure training a batch norm works with VarMaps --------- Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
This commit is contained in:

committed by
GitHub

parent
3bbb88fcb4
commit
a0d03aded1
@ -34,9 +34,14 @@ impl Var {
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
// Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
|
||||
pub fn from_tensor(t: &Tensor) -> Result<Self> {
|
||||
let inner = t.make_var()?;
|
||||
Ok(Self(inner))
|
||||
if t.is_variable() {
|
||||
Ok(Self(t.clone()))
|
||||
} else {
|
||||
let inner = t.make_var()?;
|
||||
Ok(Self(inner))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rand_f64<S: Into<Shape>>(
|
||||
|
Reference in New Issue
Block a user