From bf9e1d1c23d7813fc91f6a1a26c69ec81bda904d Mon Sep 17 00:00:00 2001
From: laurent <laurent.mazare@gmail.com>
Date: Fri, 23 Jun 2023 09:19:23 +0100
Subject: [PATCH] Add the detach method.

---
 src/tensor.rs | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/src/tensor.rs b/src/tensor.rs
index 11d69dec..53665ced 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -524,6 +524,22 @@ impl Tensor {
         Ok(Tensor(Arc::new(tensor_)))
     }
 
+    // TODO: Currently this duplicates the storage, the PyTorch version would share the storage,
+    // maybe we should do the same?
+    /// Returns a new tensor detached from the current graph, gradients are not propagated through
+    /// this new node.
+    pub fn detach(&self) -> Result<Tensor> {
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.try_clone()?,
+            shape: self.shape.clone(),
+            stride: self.stride.clone(),
+            op: None,
+            is_variable: false,
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+
     /// If the target device is the same as the tensor device, only a shallow copy is performed.
     pub fn to_device(&self, device: &Device) -> Result<Tensor> {
         if self.device().same_id(device) {