diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 7295c350..e2c97af2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -177,14 +177,9 @@ impl Tensor { is_variable: bool, ) -> Result { let none = BackpropOp::none(); - if is_variable { - let shape = shape.into(); - let storage = device.ones(&shape, dtype)?; - Ok(from_storage(storage, shape, none, is_variable)) - } else { - let storage = device.ones(&crate::shape::SCALAR, dtype)?; - from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape) - } + let shape = shape.into(); + let storage = device.ones(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) } /// Creates a new tensor filled with ones. @@ -222,14 +217,9 @@ impl Tensor { is_variable: bool, ) -> Result { let none = BackpropOp::none(); - if is_variable { - let shape = shape.into(); - let storage = device.zeros(&shape, dtype)?; - Ok(from_storage(storage, shape, none, is_variable)) - } else { - let storage = device.zeros(&crate::shape::SCALAR, dtype)?; - from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape) - } + let shape = shape.into(); + let storage = device.zeros(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) } /// Creates a new tensor filled with zeros.