From 47f9c48e7c907713559909263ee054567d8afdb0 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sat, 24 Jun 2023 07:03:21 +0100
Subject: [PATCH] Avoid duplicating the storage by refcounting it.

---
 src/tensor.rs | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/src/tensor.rs b/src/tensor.rs
index 31c6be4b..508dee49 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -17,7 +17,7 @@ impl TensorId {
 
 pub struct Tensor_ {
     id: TensorId,
-    storage: Storage,
+    storage: Arc<Storage>,
     shape: Shape,
     // The strides are given in number of elements and not in bytes.
     stride: Vec<usize>,
@@ -25,6 +25,9 @@ pub struct Tensor_ {
     is_variable: bool,
 }
 
+// Tensors are refcounted so that cloning is cheap when building the op graph.
+// Storages are also refcounted independently so that it's possible to avoid
+// copying the storage for operations that only modify the shape or stride.
 #[derive(Clone)]
 pub struct Tensor(Arc<Tensor_>);
 
@@ -104,7 +107,7 @@ fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: boo
     let stride = shape.stride_contiguous();
     let tensor_ = Tensor_ {
         id: TensorId::new(),
-        storage,
+        storage: Arc::new(storage),
         shape,
         stride,
         op,
@@ -274,7 +277,7 @@ impl Tensor {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             Ok::<_, Error>(data[0])
         };
-        match &self.storage {
+        match self.storage.as_ref() {
             Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -394,7 +397,7 @@ impl Tensor {
     /// into account so the size of the resulting buffer might be larger than the
     /// tensor number of elements.
     pub fn storage_data<S: crate::WithDType>(&self) -> Result<std::borrow::Cow<[S]>> {
-        match &self.storage {
+        match self.storage.as_ref() {
             Storage::Cpu(cpu_storage) => {
                 let slice = S::cpu_storage_as_slice(cpu_storage)?;
                 Ok(std::borrow::Cow::Borrowed(slice))
@@ -415,7 +418,7 @@ impl Tensor {
                 shape: self.shape().clone(),
             });
         }
-        match &self.storage {
+        match self.storage.as_ref() {
             Storage::Cpu(cpu_storage) => {
                 let data = S::cpu_storage_as_slice(cpu_storage)?;
                 Ok(self.strided_index().map(|i| data[i]).collect())
@@ -442,7 +445,7 @@ impl Tensor {
             assert!(src_index.next().is_none());
             Ok(rows)
         };
-        match &self.storage {
+        match self.storage.as_ref() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -465,7 +468,7 @@ impl Tensor {
             assert!(src_index.next().is_none());
             Ok(top_rows)
         };
-        match &self.storage {
+        match self.storage.as_ref() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -539,7 +542,7 @@ impl Tensor {
         };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
-            storage: self.storage.try_clone()?,
+            storage: self.storage.clone(),
             shape: Shape::from(dims),
             stride,
             op,
@@ -557,7 +560,7 @@ impl Tensor {
     pub fn copy(&self) -> Result<Tensor> {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
-            storage: self.storage.try_clone()?,
+            storage: Arc::new(self.storage.try_clone()?),
             shape: self.shape.clone(),
             stride: self.stride.clone(),
             op: self.op.clone(),
@@ -566,14 +569,12 @@ impl Tensor {
         Ok(Tensor(Arc::new(tensor_)))
     }
 
-    // TODO: Currently this duplicates the storage, the PyTorch version would share the storage,
-    // maybe we should do the same?
     /// Returns a new tensor detached from the current graph, gradients are not propagated through
     /// this new node.
     pub fn detach(&self) -> Result<Tensor> {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
-            storage: self.storage.try_clone()?,
+            storage: self.storage.clone(),
             shape: self.shape.clone(),
             stride: self.stride.clone(),
             op: None,
@@ -587,7 +588,7 @@ impl Tensor {
         if self.device().same_id(device) {
             Ok(self.clone())
         } else {
-            let storage = match (&self.storage, device) {
+            let storage = match (self.storage.as_ref(), device) {
                 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                     Storage::Cuda(cuda.cuda_from_cpu_storage(storage)?)
                 }
@@ -607,7 +608,7 @@ impl Tensor {
             };
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
-                storage,
+                storage: Arc::new(storage),
                 shape: self.shape.clone(),
                 stride: self.stride.clone(),
                 op,
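
The net effect of the patch: Tensor_ now holds an Arc<Storage> instead of owning
its Storage, so metadata-only operations (the reshape hunk above) and detach bump
a refcount rather than calling try_clone, and only the explicit copy() still
duplicates the buffer. Below is a minimal self-contained sketch of that pattern,
not the crate's actual API: the Storage struct, the op tag, and the method bodies
are simplified stand-ins.

use std::sync::Arc;

// Simplified stand-ins for the real types in src/tensor.rs: the buffer is a
// plain Vec<f32> and the op graph is a string tag. Only the sharing pattern
// is the same.
struct Storage(Vec<f32>);

struct Tensor {
    // Refcounted: cloning the Arc bumps a counter, the buffer is not copied.
    storage: Arc<Storage>,
    shape: Vec<usize>,
    op: Option<&'static str>,
}

impl Tensor {
    fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        Tensor { storage: Arc::new(Storage(data)), shape, op: None }
    }

    // Shape-only operation: shares the storage, only the metadata changes.
    fn reshape(&self, shape: Vec<usize>) -> Self {
        Tensor { storage: self.storage.clone(), shape, op: Some("reshape") }
    }

    // Detach also shares the storage and only severs the op graph, which is
    // the PyTorch behaviour mentioned in the removed TODO.
    fn detach(&self) -> Self {
        Tensor { storage: self.storage.clone(), shape: self.shape.clone(), op: None }
    }

    // Copy is the one operation that still duplicates the buffer, wrapping
    // the fresh storage in a new Arc as the patched copy() does.
    fn copy(&self) -> Self {
        Tensor {
            storage: Arc::new(Storage(self.storage.0.clone())),
            shape: self.shape.clone(),
            op: self.op,
        }
    }
}

fn main() {
    let t = Tensor::new(vec![1., 2., 3., 4.], vec![2, 2]);
    assert!(Arc::ptr_eq(&t.storage, &t.reshape(vec![4]).storage)); // shared
    assert!(Arc::ptr_eq(&t.storage, &t.detach().storage)); // shared
    assert!(!Arc::ptr_eq(&t.storage, &t.copy().storage)); // duplicated
}

This sharing is sound as long as a storage is never mutated through one of its
aliases; an in-place operation would need something like copy-on-write on top.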
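
The repeated rewrite of "match &self.storage" into "match self.storage.as_ref()"
is mechanical fallout of the type change: with storage now behind an Arc, the
enum variants have to be matched through an explicit deref, since match
ergonomics do not look through Arc. A tiny illustration, with a cut-down
Storage enum standing in for the real one:

use std::sync::Arc;

enum Storage {
    Cpu(Vec<f32>),
    Cuda(usize), // stand-in for a device buffer handle
}

fn elem_count(storage: &Arc<Storage>) -> usize {
    // as_ref() turns the Arc<Storage> into a &Storage so the enum variants
    // can be matched; `match &**storage` works equally well, while matching
    // on the Arc itself would not compile.
    match storage.as_ref() {
        Storage::Cpu(data) => data.len(),
        Storage::Cuda(len) => *len,
    }
}

fn main() {
    let cpu = Arc::new(Storage::Cpu(vec![1., 2., 3.]));
    let gpu = Arc::new(Storage::Cuda(8));
    assert_eq!(elem_count(&cpu), 3);
    assert_eq!(elem_count(&gpu), 8);
}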