Lazy detach. (#1242)

Laurent Mazare
2023-11-02 08:33:48 +01:00
committed by GitHub
parent 6c990a33ea
commit fbd69f952c
2 changed files with 20 additions and 10 deletions


@ -920,6 +920,10 @@ impl BackpropOp {
        };
        Self(op)
    }

    pub(crate) fn is_none(&self) -> bool {
        self.0.is_none()
    }
}

impl std::ops::Deref for BackpropOp {
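Note: `BackpropOp` also implements `Deref<Target = Option<Op>>`, so the new inherent `is_none` mainly makes call sites such as `self.op.is_none()` read unambiguously, since inherent methods are resolved before any deref. A minimal standalone sketch of the pattern (an `i32` stands in for candle's `Op` enum; not candle's actual types):

```rust
use std::ops::Deref;

// Newtype over an optional op, mirroring candle's BackpropOp shape.
struct BackpropOp(Option<i32>);

impl BackpropOp {
    // The new helper: true when no backprop op was recorded.
    pub(crate) fn is_none(&self) -> bool {
        self.0.is_none()
    }
}

impl Deref for BackpropOp {
    type Target = Option<i32>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

fn main() {
    let op = BackpropOp(None);
    // The inherent method is picked before any deref to Option<i32>.
    assert!(op.is_none());
}
```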


@ -1807,7 +1807,12 @@ impl Tensor {
    /// Returns a new tensor detached from the current graph; gradients are not propagated
    /// through this new node. The storage of this tensor is shared with the initial tensor.
    ///
    /// If the tensor is already detached from the computation graph, the same tensor is returned.
    pub fn detach(&self) -> Result<Tensor> {
        if self.op.is_none() && !self.is_variable {
            Ok(self.clone())
        } else {
            let tensor_ = Tensor_ {
                id: TensorId::new(),
                storage: self.storage.clone(),
@ -1819,6 +1824,7 @@ impl Tensor {
            };
            Ok(Tensor(Arc::new(tensor_)))
        }
    }
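With this change, `detach` is lazy: a tensor that has no recorded backprop op and is not a variable is returned as a cheap clone of the same node instead of a freshly allocated one. A small usage sketch, assuming candle-core's public API at this commit (`detach` still returns `Result<Tensor>` here, and `Tensor::id` exposes the node id):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;

    // A plain tensor: no backprop op and not a variable, so detach()
    // can return a clone that shares the same underlying node.
    let t = Tensor::new(&[1f32, 2., 3.], &dev)?;
    let d = t.detach()?;
    assert_eq!(t.id(), d.id()); // same TensorId: no new node allocated

    // The result of an op carries a backprop op, so detach() still
    // builds a fresh node with a new id (storage is shared either way).
    let s = (&t + &t)?;
    let sd = s.detach()?;
    assert_ne!(s.id(), sd.id());
    Ok(())
}
```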
    /// If the target device is the same as the tensor device, only a shallow copy is performed.
    pub fn to_device(&self, device: &Device) -> Result<Tensor> {
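For the same-device fast path described by the doc line above, a brief sketch, again assuming candle-core's public API:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::zeros((2, 2), DType::F32, &dev)?;
    // Same device: to_device performs only a shallow copy, so no
    // data is actually transferred.
    let u = t.to_device(&dev)?;
    assert_eq!(u.dims(), &[2, 2]);
    Ok(())
}
```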