diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index e1168c2e..ce6d970e 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -920,6 +920,10 @@ impl BackpropOp {
         };
         Self(op)
     }
+
+    pub(crate) fn is_none(&self) -> bool {
+        self.0.is_none()
+    }
 }
 
 impl std::ops::Deref for BackpropOp {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index adcdc59d..2a5d3635 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1807,17 +1807,23 @@ impl Tensor {
 
     /// Returns a new tensor detached from the current graph, gradient are not propagated through
     /// this new node. The storage of this tensor is shared with the initial tensor.
+    ///
+    /// If the tensor is already detached from the computation graph, the same tensor is returned.
     pub fn detach(&self) -> Result<Tensor> {
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage: self.storage.clone(),
-            layout: self.layout.clone(),
-            op: BackpropOp::none(),
-            is_variable: false,
-            dtype: self.dtype,
-            device: self.device.clone(),
-        };
-        Ok(Tensor(Arc::new(tensor_)))
+        if self.op.is_none() && !self.is_variable {
+            Ok(self.clone())
+        } else {
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage: self.storage.clone(),
+                layout: self.layout.clone(),
+                op: BackpropOp::none(),
+                is_variable: false,
+                dtype: self.dtype,
+                device: self.device.clone(),
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        }
     }
 
     /// If the target device is the same as the tensor device, only a shallow copy is performed.
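
Not part of the patch, just a caller-side sketch of the behavior it introduces. It assumes the candle-core API as of this change (detach() still returning a Result) and the public Tensor::id accessor; the function name is illustrative.

use candle_core::{Device, Result, Tensor};

fn detach_fast_path_demo() -> Result<()> {
    let dev = Device::Cpu;

    // `a` comes straight from data: it has no backprop op and is not a
    // variable, so the new fast path returns a clone of the same Arc
    // handle and no new TensorId is allocated.
    let a = Tensor::new(&[1f32, 2.0, 3.0], &dev)?;
    let a_detached = a.detach()?;
    assert_eq!(a.id(), a_detached.id());

    // `b` is the output of an op, so it carries a BackpropOp; detach()
    // still allocates a fresh node with the op cleared while sharing the
    // underlying storage.
    let b = a.affine(2.0, 1.0)?;
    let b_detached = b.detach()?;
    assert_ne!(b.id(), b_detached.id());
    assert_eq!(b_detached.to_vec1::<f32>()?, vec![3.0, 5.0, 7.0]);

    Ok(())
}

Note that the guard also checks !self.is_variable, so detaching a variable still produces a fresh, non-variable node rather than returning the variable itself.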