From 7d9a8ff3f97f837859ae31a86999ed14b80217f6 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 22 Jun 2023 16:29:18 +0100 Subject: [PATCH] Do not ignore errors when cloning the storage. --- src/cuda_backend.rs | 9 ++++++++- src/dummy_cuda_backend.rs | 6 +++++- src/storage.rs | 14 +++++++++++++- src/tensor.rs | 2 +- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 741ab0fd..c34cd3ea 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -107,13 +107,20 @@ impl CudaDevice { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub enum CudaStorage { F32(CudaSlice), F64(CudaSlice), } impl CudaStorage { + pub fn try_clone(&self) -> Result { + match self { + Self::F32(slice) => Ok(Self::F32(slice.try_clone()?)), + Self::F64(slice) => Ok(Self::F64(slice.try_clone()?)), + } + } + pub fn dtype(&self) -> DType { match self { Self::F32(_) => DType::F32, diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs index 63b55bac..939247cd 100644 --- a/src/dummy_cuda_backend.rs +++ b/src/dummy_cuda_backend.rs @@ -34,10 +34,14 @@ impl CudaDevice { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct CudaStorage; impl CudaStorage { + pub fn try_clone(&self) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn dtype(&self) -> DType { fail!() } diff --git a/src/storage.rs b/src/storage.rs index 00746089..b4c2c272 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,12 +1,24 @@ use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape}; -#[derive(Debug, Clone)] +// We do not want to implement Clone on Storage as cloning may fail because of +// out of memory. Instead try_clone should be used. +#[derive(Debug)] pub enum Storage { Cpu(CpuStorage), Cuda(CudaStorage), } impl Storage { + pub fn try_clone(&self) -> Result { + match self { + Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())), + Self::Cuda(storage) => { + let storage = storage.try_clone()?; + Ok(Self::Cuda(storage)) + } + } + } + pub fn device(&self) -> Device { match self { Self::Cpu(_) => Device::Cpu, diff --git a/src/tensor.rs b/src/tensor.rs index 72b756a0..f1e60efc 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -441,7 +441,7 @@ impl Tensor { (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]); let tensor_ = Tensor_ { id: TensorId::new(), - storage: self.storage.clone(), + storage: self.storage.try_clone()?, shape: Shape::from(dims), stride, // TODO The op should have a backward