mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Only keep track of the graph when needed.
This commit is contained in:
@ -49,12 +49,17 @@ macro_rules! unary_op {
|
||||
let storage = self
|
||||
.storage
|
||||
.unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::$op_name(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
stride: shape.stride_contiguous(),
|
||||
op: Some(Op::$op_name(self.clone())),
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
@ -72,12 +77,17 @@ macro_rules! binary_op {
|
||||
self.stride(),
|
||||
rhs.stride(),
|
||||
)?;
|
||||
let op = if self.track_op() || rhs.track_op() {
|
||||
Some(Op::$op_name(self.clone(), rhs.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
stride: shape.stride_contiguous(),
|
||||
op: Some(Op::$op_name(self.clone(), rhs.clone())),
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
@ -214,6 +224,12 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the computation graph should track this op, that is if it is
|
||||
/// a variable or if it has some variable as dependencies.
|
||||
pub(crate) fn track_op(&self) -> bool {
|
||||
self.is_variable || self.op.is_some()
|
||||
}
|
||||
|
||||
// TODO: Also make an inplace version or a pre-allocated? This could be tricky
|
||||
// if this can create cycles in the compute graph.
|
||||
binary_op!(add, Add);
|
||||
@ -247,16 +263,21 @@ impl Tensor {
|
||||
let storage = self
|
||||
.storage
|
||||
.affine_impl(self.shape(), self.stride(), mul, add)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Affine {
|
||||
arg: self.clone(),
|
||||
mul,
|
||||
add,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
stride: shape.stride_contiguous(),
|
||||
op: Some(Op::Affine {
|
||||
arg: self.clone(),
|
||||
mul,
|
||||
add,
|
||||
}),
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
@ -307,12 +328,17 @@ impl Tensor {
|
||||
self.stride(),
|
||||
rhs.stride(),
|
||||
)?;
|
||||
let op = if self.track_op() || rhs.track_op() {
|
||||
Some(Op::Matmul(self.clone(), rhs.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape: c_shape,
|
||||
stride: c_stride,
|
||||
op: Some(Op::Matmul(self.clone(), rhs.clone())),
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
|
Reference in New Issue
Block a user