Add the detach method.

This commit is contained in:
laurent
2023-06-23 09:19:23 +01:00
parent 3e7cb18d7f
commit bf9e1d1c23

@@ -524,6 +524,22 @@ impl Tensor {
        Ok(Tensor(Arc::new(tensor_)))
    }
    // TODO: Currently this duplicates the storage; the PyTorch version would share
    // the storage instead, maybe we should do the same?
    /// Returns a new tensor detached from the current graph; gradients are not propagated
    /// through this new node.
    pub fn detach(&self) -> Result<Tensor> {
        let tensor_ = Tensor_ {
            id: TensorId::new(),
            storage: self.storage.try_clone()?,
            shape: self.shape.clone(),
            stride: self.stride.clone(),
            op: None,
            is_variable: false,
        };
        Ok(Tensor(Arc::new(tensor_)))
    }
    /// If the target device is the same as the tensor device, only a shallow copy is performed.
    pub fn to_device(&self, device: &Device) -> Result<Tensor> {
        if self.device().same_id(device) {
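
For context, a minimal usage sketch of the new method (not part of this commit). The constructor `Tensor::var` and the `Device::Cpu` value are assumptions about the surrounding crate's API, not confirmed by this diff:

    fn detach_example() -> Result<()> {
        // `Tensor::var` and `Device::Cpu` are assumed to exist in the
        // surrounding crate; they do not appear in this diff.
        let x = Tensor::var(&[3f32, 1f32, 4f32], &Device::Cpu)?;
        let y = x.detach()?;
        // `y` holds a copy of the same values (the storage is duplicated via
        // `try_clone`), but because `op` is `None` and `is_variable` is
        // `false`, no gradient propagates back through this node.
        Ok(())
    }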