Add a dummy module.

This commit is contained in:
laurent
2023-06-21 21:16:00 +01:00
parent 97d9142dee
commit 304a557d84
5 changed files with 57 additions and 4 deletions

View File

@ -22,5 +22,5 @@ rand = "0.8.5"
tokenizers = "0.13.3"
[features]
default = ["cuda"]
default = []
cuda = ["dep:cudarc"]

View File

@ -1,7 +1,7 @@
use crate::{CpuStorage, DType, Result, Shape};
use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
pub(crate) type Error = cudarc::driver::DriverError;
pub type CudaError = cudarc::driver::DriverError;
#[derive(Debug, Clone)]
pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);

46
src/dummy_cuda_backend.rs Normal file
View File

@ -0,0 +1,46 @@
#![allow(dead_code)]
use crate::{CpuStorage, DType, Result, Shape};
pub type CudaError = std::io::Error;
#[derive(Debug, Clone)]
pub struct CudaDevice;
impl CudaDevice {
pub(crate) fn new(_: usize) -> Result<Self> {
unimplemented!("cuda support hasn't been enabled")
}
pub(crate) fn ordinal(&self) -> usize {
unimplemented!("cuda support hasn't been enabled")
}
pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
unimplemented!("cuda support hasn't been enabled")
}
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
unimplemented!("cuda support hasn't been enabled")
}
}
#[derive(Debug, Clone)]
pub struct CudaStorage;
impl CudaStorage {
pub fn dtype(&self) -> DType {
unimplemented!()
}
pub fn device(&self) -> CudaDevice {
unimplemented!()
}
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
unimplemented!()
}
pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
unimplemented!()
}
}

View File

@ -35,7 +35,7 @@ pub enum Error {
},
#[error(transparent)]
Cudarc(#[from] crate::cuda_backend::Error),
Cuda(#[from] crate::CudaError),
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@ -1,7 +1,9 @@
mod cpu_backend;
#[cfg(feature = "cuda")]
mod cuda_backend;
mod device;
mod dtype;
mod dummy_cuda_backend;
mod error;
mod op;
mod shape;
@ -10,7 +12,6 @@ mod strided_index;
mod tensor;
pub use cpu_backend::CpuStorage;
pub use cuda_backend::{CudaDevice, CudaStorage};
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
@ -18,3 +19,9 @@ pub use shape::Shape;
pub use storage::Storage;
use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};
#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaError, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaError, CudaStorage};