Do not ignore errors when cloning the storage.

This commit is contained in:
laurent
2023-06-22 16:29:18 +01:00
parent 2f7a072250
commit 7d9a8ff3f9
4 changed files with 27 additions and 4 deletions

View File

@ -107,13 +107,20 @@ impl CudaDevice {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum CudaStorage {
F32(CudaSlice<f32>),
F64(CudaSlice<f64>),
}
impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> {
match self {
Self::F32(slice) => Ok(Self::F32(slice.try_clone()?)),
Self::F64(slice) => Ok(Self::F64(slice.try_clone()?)),
}
}
pub fn dtype(&self) -> DType {
match self {
Self::F32(_) => DType::F32,

View File

@ -34,10 +34,14 @@ impl CudaDevice {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct CudaStorage;
impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dtype(&self) -> DType {
fail!()
}

View File

@ -1,12 +1,24 @@
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
#[derive(Debug, Clone)]
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
#[derive(Debug)]
pub enum Storage {
Cpu(CpuStorage),
Cuda(CudaStorage),
}
impl Storage {
pub fn try_clone(&self) -> Result<Self> {
match self {
Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
Self::Cuda(storage) => {
let storage = storage.try_clone()?;
Ok(Self::Cuda(storage))
}
}
}
pub fn device(&self) -> Device {
match self {
Self::Cpu(_) => Device::Cpu,

View File

@ -441,7 +441,7 @@ impl Tensor {
(stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
storage: self.storage.try_clone()?,
shape: Shape::from(dims),
stride,
// TODO The op should have a backward