Do not ignore errors when cloning the storage.

laurent
2023-06-22 16:29:18 +01:00
parent 2f7a072250
commit 7d9a8ff3f9
4 changed files with 27 additions and 4 deletions

View File

@@ -107,13 +107,20 @@ impl CudaDevice {
     }
 }
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub enum CudaStorage {
     F32(CudaSlice<f32>),
     F64(CudaSlice<f64>),
 }
 impl CudaStorage {
+    pub fn try_clone(&self) -> Result<Self> {
+        match self {
+            Self::F32(slice) => Ok(Self::F32(slice.try_clone()?)),
+            Self::F64(slice) => Ok(Self::F64(slice.try_clone()?)),
+        }
+    }
+
     pub fn dtype(&self) -> DType {
         match self {
             Self::F32(_) => DType::F32,

View File

@@ -34,10 +34,14 @@ impl CudaDevice {
     }
 }
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct CudaStorage;
 impl CudaStorage {
+    pub fn try_clone(&self) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub fn dtype(&self) -> DType {
         fail!()
     }
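The stub keeps the same try_clone signature as the real backend but always returns Error::NotCompiledWithCudaSupport, so code written against CudaStorage still compiles when CUDA is disabled. A minimal, self-contained sketch of that compile-time swap, assuming a `cuda` cargo feature selects between the two modules (the feature name and module layout here are assumptions, not necessarily the crate's actual wiring):

// Sketch only: one public name, two implementations chosen at compile time.
pub struct NotCompiledWithCudaSupport; // stand-in for the crate's Error variant

#[cfg(feature = "cuda")]
mod cuda_backend {
    pub struct CudaStorage; // the real version would wrap device memory
    impl CudaStorage {
        pub fn try_clone(&self) -> Result<Self, super::NotCompiledWithCudaSupport> {
            Ok(CudaStorage)
        }
    }
}

#[cfg(not(feature = "cuda"))]
mod dummy_cuda_backend {
    pub struct CudaStorage; // stub: same API surface, no CUDA behind it
    impl CudaStorage {
        pub fn try_clone(&self) -> Result<Self, super::NotCompiledWithCudaSupport> {
            Err(super::NotCompiledWithCudaSupport)
        }
    }
}

#[cfg(feature = "cuda")]
pub use cuda_backend::CudaStorage;
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::CudaStorage;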

View File

@@ -1,12 +1,24 @@
 use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
-#[derive(Debug, Clone)]
+// We do not want to implement Clone on Storage as cloning may fail because of
+// out of memory. Instead try_clone should be used.
+#[derive(Debug)]
 pub enum Storage {
     Cpu(CpuStorage),
     Cuda(CudaStorage),
 }
 impl Storage {
+    pub fn try_clone(&self) -> Result<Self> {
+        match self {
+            Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
+            Self::Cuda(storage) => {
+                let storage = storage.try_clone()?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub fn device(&self) -> Device {
         match self {
             Self::Cpu(_) => Device::Cpu,
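The new comment is the heart of the change: a derived Clone has no way to report failure, but cloning the CUDA variant has to allocate fresh device memory, which can run out. Below is a standalone sketch of the fallible-clone pattern; HostBuf and DeviceBuf are stand-ins for the crate's storage types, not its actual API:

#[derive(Debug)]
pub enum CloneError {
    OutOfMemory,
}

// Host memory: cloning is effectively infallible, a plain derive is fine.
#[derive(Debug, Clone)]
pub struct HostBuf(Vec<f32>);

// Device memory: cloning needs a new allocation, so it returns a Result.
#[derive(Debug)]
pub struct DeviceBuf(Vec<f32>);

impl DeviceBuf {
    pub fn try_clone(&self) -> Result<Self, CloneError> {
        Ok(DeviceBuf(self.0.clone()))
    }
}

// Deliberately no #[derive(Clone)]: callers must go through try_clone
// and decide what to do when the device allocation fails.
#[derive(Debug)]
pub enum Storage {
    Cpu(HostBuf),
    Cuda(DeviceBuf),
}

impl Storage {
    pub fn try_clone(&self) -> Result<Self, CloneError> {
        match self {
            Self::Cpu(b) => Ok(Self::Cpu(b.clone())),        // cannot fail
            Self::Cuda(b) => Ok(Self::Cuda(b.try_clone()?)), // may be OOM
        }
    }
}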

View File

@@ -441,7 +441,7 @@ impl Tensor {
         (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
         let tensor_ = Tensor_ {
             id: TensorId::new(),
-            storage: self.storage.clone(),
+            storage: self.storage.try_clone()?,
             shape: Shape::from(dims),
             stride,
             // TODO The op should have a backward
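The only call-site change is `.clone()` becoming `.try_clone()?`, which works without touching the signature because the surrounding function evidently already returns a Result; a previously infallible caller would have to change its signature too. An illustrative sketch of that ripple effect (all names below are made up for the example):

use std::io;

struct Buffer(Vec<u8>);

impl Buffer {
    fn try_clone(&self) -> io::Result<Self> {
        Ok(Buffer(self.0.clone())) // a real device buffer could fail here
    }
}

struct Holder {
    data: Buffer,
}

impl Holder {
    // Previously this could have been `-> Holder` with `self.data.clone()`;
    // the fallible clone forces the constructor itself to return a Result
    // so that `?` has somewhere to send the error.
    fn duplicate(&self) -> io::Result<Holder> {
        Ok(Holder { data: self.data.try_clone()? })
    }
}

fn main() -> io::Result<()> {
    let h = Holder { data: Buffer(vec![1, 2, 3]) };
    let _copy = h.duplicate()?; // the error propagates instead of panicking
    Ok(())
}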