Only keep track of the graph when needed.

This commit is contained in:
laurent
2023-06-22 22:06:24 +01:00
parent fc83d97b41
commit 552276749a

View File

@ -49,12 +49,17 @@ macro_rules! unary_op {
let storage = self
.storage
.unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
let op = if self.track_op() {
Some(Op::$op_name(self.clone()))
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
op: Some(Op::$op_name(self.clone())),
op,
is_variable: false,
};
Ok(Self(Arc::new(tensor_)))
@ -72,12 +77,17 @@ macro_rules! binary_op {
self.stride(),
rhs.stride(),
)?;
let op = if self.track_op() || rhs.track_op() {
Some(Op::$op_name(self.clone(), rhs.clone()))
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
op: Some(Op::$op_name(self.clone(), rhs.clone())),
op,
is_variable: false,
};
Ok(Self(Arc::new(tensor_)))
@ -214,6 +224,12 @@ impl Tensor {
}
}
/// Returns true if the computation graph should track this op, that is if it is
/// a variable or if it has some variable as dependencies.
pub(crate) fn track_op(&self) -> bool {
self.is_variable || self.op.is_some()
}
// TODO: Also make an inplace version or a pre-allocated? This could be tricky
// if this can create cycles in the compute graph.
binary_op!(add, Add);
@ -247,16 +263,21 @@ impl Tensor {
let storage = self
.storage
.affine_impl(self.shape(), self.stride(), mul, add)?;
let op = if self.track_op() {
Some(Op::Affine {
arg: self.clone(),
mul,
add,
})
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
op: Some(Op::Affine {
arg: self.clone(),
mul,
add,
}),
op,
is_variable: false,
};
Ok(Self(Arc::new(tensor_)))
@ -307,12 +328,17 @@ impl Tensor {
self.stride(),
rhs.stride(),
)?;
let op = if self.track_op() || rhs.track_op() {
Some(Op::Matmul(self.clone(), rhs.clone()))
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
shape: c_shape,
stride: c_stride,
op: Some(Op::Matmul(self.clone(), rhs.clone())),
op,
is_variable: false,
};
Ok(Self(Arc::new(tensor_)))