From 552276749a6d2ec38d89505984c9989607cbb9d1 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 22 Jun 2023 22:06:24 +0100 Subject: [PATCH] Only keep track of the graph when needed. --- src/tensor.rs | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/src/tensor.rs b/src/tensor.rs index 78359144..39eba29b 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -49,12 +49,17 @@ macro_rules! unary_op { let storage = self .storage .unary_impl::(self.shape(), self.stride())?; + let op = if self.track_op() { + Some(Op::$op_name(self.clone())) + } else { + None + }; let tensor_ = Tensor_ { id: TensorId::new(), storage, shape: shape.clone(), stride: shape.stride_contiguous(), - op: Some(Op::$op_name(self.clone())), + op, is_variable: false, }; Ok(Self(Arc::new(tensor_))) @@ -72,12 +77,17 @@ macro_rules! binary_op { self.stride(), rhs.stride(), )?; + let op = if self.track_op() || rhs.track_op() { + Some(Op::$op_name(self.clone(), rhs.clone())) + } else { + None + }; let tensor_ = Tensor_ { id: TensorId::new(), storage, shape: shape.clone(), stride: shape.stride_contiguous(), - op: Some(Op::$op_name(self.clone(), rhs.clone())), + op, is_variable: false, }; Ok(Self(Arc::new(tensor_))) @@ -214,6 +224,12 @@ impl Tensor { } } + /// Returns true if the computation graph should track this op, that is if it is + /// a variable or if it has some variable as dependencies. + pub(crate) fn track_op(&self) -> bool { + self.is_variable || self.op.is_some() + } + // TODO: Also make an inplace version or a pre-allocated? This could be tricky // if this can create cycles in the compute graph. binary_op!(add, Add); @@ -247,16 +263,21 @@ impl Tensor { let storage = self .storage .affine_impl(self.shape(), self.stride(), mul, add)?; + let op = if self.track_op() { + Some(Op::Affine { + arg: self.clone(), + mul, + add, + }) + } else { + None + }; let tensor_ = Tensor_ { id: TensorId::new(), storage, shape: shape.clone(), stride: shape.stride_contiguous(), - op: Some(Op::Affine { - arg: self.clone(), - mul, - add, - }), + op, is_variable: false, }; Ok(Self(Arc::new(tensor_))) @@ -307,12 +328,17 @@ impl Tensor { self.stride(), rhs.stride(), )?; + let op = if self.track_op() || rhs.track_op() { + Some(Op::Matmul(self.clone(), rhs.clone())) + } else { + None + }; let tensor_ = Tensor_ { id: TensorId::new(), storage, shape: c_shape, stride: c_stride, - op: Some(Op::Matmul(self.clone(), rhs.clone())), + op, is_variable: false, }; Ok(Self(Arc::new(tensor_)))