Add some specific errors rather than panicking.

This commit is contained in:
laurent
2023-06-22 07:51:53 +01:00
parent db35b31050
commit fc26bab3ed
3 changed files with 14 additions and 8 deletions

View File

@ -1,4 +1,4 @@
use crate::{CpuStorage, DType, Result, Shape}; use crate::{CpuStorage, DType, Error, Result, Shape};
use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig}; use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
pub type CudaError = cudarc::driver::DriverError; pub type CudaError = cudarc::driver::DriverError;
@ -92,7 +92,7 @@ impl CudaStorage {
match self { match self {
Self::F32(arg) => { Self::F32(arg) => {
if !shape.is_contiguous(stride) { if !shape.is_contiguous(stride) {
todo!("affine is only implemented for the contiguous case") return Err(Error::RequiresContiguous { op: "affine" });
} }
let dev = arg.device(); let dev = arg.device();
let module_name = "affine_f32"; let module_name = "affine_f32";

View File

@ -1,5 +1,5 @@
#![allow(dead_code)] #![allow(dead_code)]
use crate::{CpuStorage, DType, Result, Shape}; use crate::{CpuStorage, DType, Error, Result, Shape};
pub type CudaError = std::io::Error; pub type CudaError = std::io::Error;
@ -14,7 +14,7 @@ macro_rules! fail {
impl CudaDevice { impl CudaDevice {
pub(crate) fn new(_: usize) -> Result<Self> { pub(crate) fn new(_: usize) -> Result<Self> {
fail!() Err(Error::NotCompiledWithCudaSupport)
} }
pub(crate) fn ordinal(&self) -> usize { pub(crate) fn ordinal(&self) -> usize {
@ -22,11 +22,11 @@ impl CudaDevice {
} }
pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> { pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
fail!() Err(Error::NotCompiledWithCudaSupport)
} }
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> { pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
fail!() Err(Error::NotCompiledWithCudaSupport)
} }
} }
@ -43,10 +43,10 @@ impl CudaStorage {
} }
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> { pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
fail!() Err(Error::NotCompiledWithCudaSupport)
} }
pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> { pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
fail!() Err(Error::NotCompiledWithCudaSupport)
} }
} }

View File

@ -6,6 +6,12 @@ pub enum Error {
#[error("unexpected dtype, expected: {expected:?}, got: {got:?}")] #[error("unexpected dtype, expected: {expected:?}, got: {got:?}")]
UnexpectedDType { expected: DType, got: DType }, UnexpectedDType { expected: DType, got: DType },
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp { ShapeMismatchBinaryOp {
lhs: Shape, lhs: Shape,