From 2741b39ad37ecb58c110459739ee174fae5f1fa4 Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 29 Jun 2023 11:56:40 +0100
Subject: [PATCH] Use broadcasted scalars for const tensors.

---
 candle-core/src/backprop.rs       |  2 +-
 candle-core/src/shape.rs          |  2 ++
 candle-core/src/tensor.rs         | 17 +++++++----------
 candle-core/tests/grad_tests.rs   |  1 -
 candle-core/tests/tensor_tests.rs |  4 ++--
 5 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 7801b878..45448505 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -99,7 +99,7 @@ impl Tensor {
     pub fn backward(&self) -> Result<GradStore> {
         let sorted_nodes = self.sorted_nodes();
         let mut grads = GradStore::new();
-        grads.insert(self, self.ones_like()?);
+        grads.insert(self, self.ones_like()?.contiguous()?);
         for node in sorted_nodes.iter() {
             if node.is_variable() {
                 continue;
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index efea723b..cc068004 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -3,6 +3,8 @@ use crate::{Error, Result};
 #[derive(Clone, PartialEq, Eq)]
 pub struct Shape(Vec<usize>);
 
+pub const SCALAR: Shape = Shape(vec![]);
+
 impl std::fmt::Debug for Shape {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{:?}", &self.dims())
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 4b9b3306..6586834c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -115,16 +115,14 @@ fn from_storage<S: Into<Shape>>(
 }
 
 impl Tensor {
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let storage = device.ones(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.ones(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -132,6 +130,8 @@ impl Tensor {
     }
 
     pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        // Maybe we should allocate some actual storage for vars rather than just using a
+        // broadcasted scalar?
         Self::ones_impl(shape, dtype, device, true)
     }
 
@@ -139,16 +139,14 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let storage = device.zeros(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -599,8 +597,7 @@ impl Tensor {
         &self.layout
     }
 
-    // TODO: Rename to `stride` once the PR that introduced the layout has been merged.
-    pub fn stride_tmp(&self) -> &[usize] {
+    pub fn stride(&self) -> &[usize] {
         self.layout.stride()
     }
 
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 8632ec21..10eef780 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -20,7 +20,6 @@ fn matmul_grad(device: &Device) -> Result<()> {
     let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?;
     let data: Vec<_> = (0..12).map(|i| i as f32).collect();
     let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?;
-
     let c = x.matmul(&y)?;
     let grads = c.backward()?;
     let grad_x = grads.get(&x).context("no grad for x")?;
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 8ac0c9f2..78ca4b05 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
     let a_tt = a.t()?.contiguous()?.t()?;
     assert!(!a_tt.is_contiguous());
     assert_eq!(a.dims(), a_tt.dims());
-    assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]);
+    assert_eq!(a_tt.stride(), &[6, 1, 2]);
 
     let b_tt = b.t()?.contiguous()?.t()?;
     assert!(!b_tt.is_contiguous());
     assert_eq!(b.dims(), b_tt.dims());
-    assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]);
+    assert_eq!(b_tt.stride(), &[6, 1, 3]);
 
     assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
     assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
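Editor's note, not part of the commit: a minimal sketch of what the change means for callers, assuming the candle-core API as of this patch (Tensor::ones, stride, is_contiguous, contiguous). The crate import path `candle_core` and the exact stride values printed are assumptions for illustration only.

// Illustrative sketch; assumes the candle-core API at this commit.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // With this patch, `ones` allocates a single scalar element and broadcasts
    // it to the requested shape, so the result is not a dense, contiguous buffer.
    let t = Tensor::ones((2, 3), DType::F32, &dev)?;
    assert!(!t.is_contiguous());
    println!("broadcast strides: {:?}", t.stride());
    // Calling `contiguous` materializes real storage, which is why `backward`
    // now seeds the gradient with `ones_like()?.contiguous()?` rather than the
    // broadcasted all-ones tensor itself.
    let dense = t.contiguous()?;
    assert!(dense.is_contiguous());
    Ok(())
}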