Add some specific errors rather than panicking.

This commit is contained in:
laurent
2023-06-22 07:51:53 +01:00
parent db35b31050
commit fc26bab3ed
3 changed files with 14 additions and 8 deletions

View File

@ -1,4 +1,4 @@
use crate::{CpuStorage, DType, Result, Shape};
use crate::{CpuStorage, DType, Error, Result, Shape};
use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
pub type CudaError = cudarc::driver::DriverError;
@ -92,7 +92,7 @@ impl CudaStorage {
match self {
Self::F32(arg) => {
if !shape.is_contiguous(stride) {
todo!("affine is only implemented for the contiguous case")
return Err(Error::RequiresContiguous { op: "affine" });
}
let dev = arg.device();
let module_name = "affine_f32";

View File

@ -1,5 +1,5 @@
#![allow(dead_code)]
use crate::{CpuStorage, DType, Result, Shape};
use crate::{CpuStorage, DType, Error, Result, Shape};
pub type CudaError = std::io::Error;
@ -14,7 +14,7 @@ macro_rules! fail {
impl CudaDevice {
pub(crate) fn new(_: usize) -> Result<Self> {
fail!()
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn ordinal(&self) -> usize {
@ -22,11 +22,11 @@ impl CudaDevice {
}
pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
fail!()
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
fail!()
Err(Error::NotCompiledWithCudaSupport)
}
}
@ -43,10 +43,10 @@ impl CudaStorage {
}
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
fail!()
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
fail!()
Err(Error::NotCompiledWithCudaSupport)
}
}

View File

@ -6,6 +6,12 @@ pub enum Error {
#[error("unexpected dtype, expected: {expected:?}, got: {got:?}")]
UnexpectedDType { expected: DType, got: DType },
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp {
lhs: Shape,