mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Do not ignore errors when cloning the storage.
This commit is contained in:
@ -107,13 +107,20 @@ impl CudaDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug)]
|
||||||
pub enum CudaStorage {
|
pub enum CudaStorage {
|
||||||
F32(CudaSlice<f32>),
|
F32(CudaSlice<f32>),
|
||||||
F64(CudaSlice<f64>),
|
F64(CudaSlice<f64>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CudaStorage {
|
impl CudaStorage {
|
||||||
|
pub fn try_clone(&self) -> Result<Self> {
|
||||||
|
match self {
|
||||||
|
Self::F32(slice) => Ok(Self::F32(slice.try_clone()?)),
|
||||||
|
Self::F64(slice) => Ok(Self::F64(slice.try_clone()?)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
Self::F32(_) => DType::F32,
|
Self::F32(_) => DType::F32,
|
||||||
|
@ -34,10 +34,14 @@ impl CudaDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug)]
|
||||||
pub struct CudaStorage;
|
pub struct CudaStorage;
|
||||||
|
|
||||||
impl CudaStorage {
|
impl CudaStorage {
|
||||||
|
pub fn try_clone(&self) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
fail!()
|
fail!()
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,24 @@
|
|||||||
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
|
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||||
|
// out of memory. Instead try_clone should be used.
|
||||||
|
#[derive(Debug)]
|
||||||
pub enum Storage {
|
pub enum Storage {
|
||||||
Cpu(CpuStorage),
|
Cpu(CpuStorage),
|
||||||
Cuda(CudaStorage),
|
Cuda(CudaStorage),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Storage {
|
impl Storage {
|
||||||
|
pub fn try_clone(&self) -> Result<Self> {
|
||||||
|
match self {
|
||||||
|
Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
|
||||||
|
Self::Cuda(storage) => {
|
||||||
|
let storage = storage.try_clone()?;
|
||||||
|
Ok(Self::Cuda(storage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn device(&self) -> Device {
|
pub fn device(&self) -> Device {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(_) => Device::Cpu,
|
Self::Cpu(_) => Device::Cpu,
|
||||||
|
@ -441,7 +441,7 @@ impl Tensor {
|
|||||||
(stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
|
(stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage: self.storage.clone(),
|
storage: self.storage.try_clone()?,
|
||||||
shape: Shape::from(dims),
|
shape: Shape::from(dims),
|
||||||
stride,
|
stride,
|
||||||
// TODO The op should have a backward
|
// TODO The op should have a backward
|
||||||
|
Reference in New Issue
Block a user