mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Only keep track of the graph when needed.
This commit is contained in:
@ -49,12 +49,17 @@ macro_rules! unary_op {
|
|||||||
let storage = self
|
let storage = self
|
||||||
.storage
|
.storage
|
||||||
.unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
|
.unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
|
||||||
|
let op = if self.track_op() {
|
||||||
|
Some(Op::$op_name(self.clone()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage,
|
storage,
|
||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
stride: shape.stride_contiguous(),
|
stride: shape.stride_contiguous(),
|
||||||
op: Some(Op::$op_name(self.clone())),
|
op,
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
@ -72,12 +77,17 @@ macro_rules! binary_op {
|
|||||||
self.stride(),
|
self.stride(),
|
||||||
rhs.stride(),
|
rhs.stride(),
|
||||||
)?;
|
)?;
|
||||||
|
let op = if self.track_op() || rhs.track_op() {
|
||||||
|
Some(Op::$op_name(self.clone(), rhs.clone()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage,
|
storage,
|
||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
stride: shape.stride_contiguous(),
|
stride: shape.stride_contiguous(),
|
||||||
op: Some(Op::$op_name(self.clone(), rhs.clone())),
|
op,
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
@ -214,6 +224,12 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if the computation graph should track this op, that is if it is
|
||||||
|
/// a variable or if it has some variable as dependencies.
|
||||||
|
pub(crate) fn track_op(&self) -> bool {
|
||||||
|
self.is_variable || self.op.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Also make an inplace version or a pre-allocated? This could be tricky
|
// TODO: Also make an inplace version or a pre-allocated? This could be tricky
|
||||||
// if this can create cycles in the compute graph.
|
// if this can create cycles in the compute graph.
|
||||||
binary_op!(add, Add);
|
binary_op!(add, Add);
|
||||||
@ -247,16 +263,21 @@ impl Tensor {
|
|||||||
let storage = self
|
let storage = self
|
||||||
.storage
|
.storage
|
||||||
.affine_impl(self.shape(), self.stride(), mul, add)?;
|
.affine_impl(self.shape(), self.stride(), mul, add)?;
|
||||||
|
let op = if self.track_op() {
|
||||||
|
Some(Op::Affine {
|
||||||
|
arg: self.clone(),
|
||||||
|
mul,
|
||||||
|
add,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage,
|
storage,
|
||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
stride: shape.stride_contiguous(),
|
stride: shape.stride_contiguous(),
|
||||||
op: Some(Op::Affine {
|
op,
|
||||||
arg: self.clone(),
|
|
||||||
mul,
|
|
||||||
add,
|
|
||||||
}),
|
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
@ -307,12 +328,17 @@ impl Tensor {
|
|||||||
self.stride(),
|
self.stride(),
|
||||||
rhs.stride(),
|
rhs.stride(),
|
||||||
)?;
|
)?;
|
||||||
|
let op = if self.track_op() || rhs.track_op() {
|
||||||
|
Some(Op::Matmul(self.clone(), rhs.clone()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage,
|
storage,
|
||||||
shape: c_shape,
|
shape: c_shape,
|
||||||
stride: c_stride,
|
stride: c_stride,
|
||||||
op: Some(Op::Matmul(self.clone(), rhs.clone())),
|
op,
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
|
Reference in New Issue
Block a user