Only keep track of the graph when needed.

author laurent
date 2023-06-22 22:06:24 +01:00
parent fc83d97b41
commit 552276749a


@@ -49,12 +49,17 @@ macro_rules! unary_op {
             let storage = self
                 .storage
                 .unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
+            let op = if self.track_op() {
+                Some(Op::$op_name(self.clone()))
+            } else {
+                None
+            };
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage,
                 shape: shape.clone(),
                 stride: shape.stride_contiguous(),
-                op: Some(Op::$op_name(self.clone())),
+                op,
                 is_variable: false,
             };
             Ok(Self(Arc::new(tensor_)))
@@ -72,12 +77,17 @@ macro_rules! binary_op {
                 self.stride(),
                 rhs.stride(),
             )?;
+            let op = if self.track_op() || rhs.track_op() {
+                Some(Op::$op_name(self.clone(), rhs.clone()))
+            } else {
+                None
+            };
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage,
                 shape: shape.clone(),
                 stride: shape.stride_contiguous(),
-                op: Some(Op::$op_name(self.clone(), rhs.clone())),
+                op,
                 is_variable: false,
             };
             Ok(Self(Arc::new(tensor_)))
@@ -214,6 +224,12 @@ impl Tensor {
         }
     }
 
+    /// Returns true if the computation graph should track this op, that is if it is
+    /// a variable or if it has some variable as dependencies.
+    pub(crate) fn track_op(&self) -> bool {
+        self.is_variable || self.op.is_some()
+    }
+
     // TODO: Also make an inplace version or a pre-allocated? This could be tricky
     // if this can create cycles in the compute graph.
     binary_op!(add, Add);
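
Taken together, these changes mean the autograd graph is only materialized when a variable sits somewhere upstream of an op. Below is a minimal, self-contained Rust sketch of that idea using hypothetical stand-in types (the real Tensor_, Op, and constructors in this crate also carry storage, shape, stride and a TensorId); it illustrates how track_op propagates, it is not the actual implementation.

use std::sync::Arc;

// Hypothetical, simplified stand-ins for the crate's Op / Tensor_ / Tensor;
// the real structs also hold storage, shape, stride and an id.
enum Op {
    Neg(Tensor),
    Add(Tensor, Tensor),
}

struct Tensor_ {
    op: Option<Op>,
    is_variable: bool,
}

#[derive(Clone)]
struct Tensor(Arc<Tensor_>);

impl Tensor {
    fn constant() -> Self {
        Tensor(Arc::new(Tensor_ { op: None, is_variable: false }))
    }

    fn variable() -> Self {
        Tensor(Arc::new(Tensor_ { op: None, is_variable: true }))
    }

    // Same predicate as in the patch: track if this is a variable or was
    // itself produced by a tracked op.
    fn track_op(&self) -> bool {
        self.0.is_variable || self.0.op.is_some()
    }

    fn neg(&self) -> Self {
        // Unary pattern: only clone `self` into the op when tracking is needed.
        let op = if self.track_op() {
            Some(Op::Neg(self.clone()))
        } else {
            None
        };
        Tensor(Arc::new(Tensor_ { op, is_variable: false }))
    }

    fn add(&self, rhs: &Self) -> Self {
        // Binary pattern: track if either operand needs tracking.
        let op = if self.track_op() || rhs.track_op() {
            Some(Op::Add(self.clone(), rhs.clone()))
        } else {
            None
        };
        Tensor(Arc::new(Tensor_ { op, is_variable: false }))
    }
}

fn main() {
    let c = Tensor::constant();
    let v = Tensor::variable();

    // Constants only: no op is ever recorded, so no graph is kept alive.
    let y = c.neg().add(&c);
    assert!(y.0.op.is_none() && !y.track_op());

    // A variable anywhere upstream: the op is recorded and, because
    // op.is_some() feeds back into track_op, tracking propagates.
    let z = v.add(&c).neg();
    assert!(z.0.op.is_some() && z.track_op());
}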
@@ -247,16 +263,21 @@ impl Tensor {
         let storage = self
             .storage
             .affine_impl(self.shape(), self.stride(), mul, add)?;
+        let op = if self.track_op() {
+            Some(Op::Affine {
+                arg: self.clone(),
+                mul,
+                add,
+            })
+        } else {
+            None
+        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage,
             shape: shape.clone(),
             stride: shape.stride_contiguous(),
-            op: Some(Op::Affine {
-                arg: self.clone(),
-                mul,
-                add,
-            }),
+            op,
             is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
@@ -307,12 +328,17 @@ impl Tensor {
             self.stride(),
             rhs.stride(),
         )?;
+        let op = if self.track_op() || rhs.track_op() {
+            Some(Op::Matmul(self.clone(), rhs.clone()))
+        } else {
+            None
+        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage,
             shape: c_shape,
             stride: c_stride,
-            op: Some(Op::Matmul(self.clone(), rhs.clone())),
+            op,
             is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
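
A practical consequence of this gating (an observation about the Arc-based design rather than something stated in the patch): when op ends up as None, the new tensor holds no clones of its inputs, so intermediate results of a constant-only, inference-style computation can be dropped as soon as they go out of scope, whereas a tracked tensor keeps its whole history alive through Op until the graph is no longer needed.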