mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Do not ignore errors when cloning the storage.
This commit is contained in:
@ -107,13 +107,20 @@ impl CudaDevice {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub enum CudaStorage {
|
||||
F32(CudaSlice<f32>),
|
||||
F64(CudaSlice<f64>),
|
||||
}
|
||||
|
||||
impl CudaStorage {
|
||||
pub fn try_clone(&self) -> Result<Self> {
|
||||
match self {
|
||||
Self::F32(slice) => Ok(Self::F32(slice.try_clone()?)),
|
||||
Self::F64(slice) => Ok(Self::F64(slice.try_clone()?)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Self::F32(_) => DType::F32,
|
||||
|
@ -34,10 +34,14 @@ impl CudaDevice {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub struct CudaStorage;
|
||||
|
||||
impl CudaStorage {
|
||||
pub fn try_clone(&self) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
fail!()
|
||||
}
|
||||
|
@ -1,12 +1,24 @@
|
||||
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
// out of memory. Instead try_clone should be used.
|
||||
#[derive(Debug)]
|
||||
pub enum Storage {
|
||||
Cpu(CpuStorage),
|
||||
Cuda(CudaStorage),
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
pub fn try_clone(&self) -> Result<Self> {
|
||||
match self {
|
||||
Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.try_clone()?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
match self {
|
||||
Self::Cpu(_) => Device::Cpu,
|
||||
|
@ -441,7 +441,7 @@ impl Tensor {
|
||||
(stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
storage: self.storage.try_clone()?,
|
||||
shape: Shape::from(dims),
|
||||
stride,
|
||||
// TODO The op should have a backward
|
||||
|
Reference in New Issue
Block a user