diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index d12db972..7858e542 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -1,4 +1,4 @@ -use crate::{CpuStorage, DType, Result, Shape}; +use crate::{CpuStorage, DType, Error, Result, Shape}; use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig}; pub type CudaError = cudarc::driver::DriverError; @@ -92,7 +92,7 @@ impl CudaStorage { match self { Self::F32(arg) => { if !shape.is_contiguous(stride) { - todo!("affine is only implemented for the contiguous case") + return Err(Error::RequiresContiguous { op: "affine" }); } let dev = arg.device(); let module_name = "affine_f32"; diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs index f555327f..85b5f598 100644 --- a/src/dummy_cuda_backend.rs +++ b/src/dummy_cuda_backend.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] -use crate::{CpuStorage, DType, Result, Shape}; +use crate::{CpuStorage, DType, Error, Result, Shape}; pub type CudaError = std::io::Error; @@ -14,7 +14,7 @@ macro_rules! fail { impl CudaDevice { pub(crate) fn new(_: usize) -> Result { - fail!() + Err(Error::NotCompiledWithCudaSupport) } pub(crate) fn ordinal(&self) -> usize { @@ -22,11 +22,11 @@ impl CudaDevice { } pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result { - fail!() + Err(Error::NotCompiledWithCudaSupport) } pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result { - fail!() + Err(Error::NotCompiledWithCudaSupport) } } @@ -43,10 +43,10 @@ impl CudaStorage { } pub(crate) fn to_cpu_storage(&self) -> Result { - fail!() + Err(Error::NotCompiledWithCudaSupport) } pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result { - fail!() + Err(Error::NotCompiledWithCudaSupport) } } diff --git a/src/error.rs b/src/error.rs index 3f142960..27201cb4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,12 @@ pub enum Error { #[error("unexpected dtype, expected: {expected:?}, got: {got:?}")] UnexpectedDType { expected: DType, got: DType }, + #[error("{op} only supports contiguous tensors")] + RequiresContiguous { op: &'static str }, + + #[error("the candle crate has not been built with cuda support")] + NotCompiledWithCudaSupport, + #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] ShapeMismatchBinaryOp { lhs: Shape,