From 37dbbff261f1641db6dc868fc4dded5f8cb25a1f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 11 Oct 2023 09:16:04 +0200 Subject: [PATCH] Use full tensors for zeros and ones (#1071) * Only optimize float tensors. * Use full tensors for zeros and ones. --- candle-core/src/tensor.rs | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 7295c350..e2c97af2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -177,14 +177,9 @@ impl Tensor { is_variable: bool, ) -> Result { let none = BackpropOp::none(); - if is_variable { - let shape = shape.into(); - let storage = device.ones(&shape, dtype)?; - Ok(from_storage(storage, shape, none, is_variable)) - } else { - let storage = device.ones(&crate::shape::SCALAR, dtype)?; - from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape) - } + let shape = shape.into(); + let storage = device.ones(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) } /// Creates a new tensor filled with ones. @@ -222,14 +217,9 @@ impl Tensor { is_variable: bool, ) -> Result { let none = BackpropOp::none(); - if is_variable { - let shape = shape.into(); - let storage = device.zeros(&shape, dtype)?; - Ok(from_storage(storage, shape, none, is_variable)) - } else { - let storage = device.zeros(&crate::shape::SCALAR, dtype)?; - from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape) - } + let shape = shape.into(); + let storage = device.zeros(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) } /// Creates a new tensor filled with zeros.