From 8967c46563221c01db4fc6a920231a9ef0d6f7bc Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 4 Apr 2024 08:27:23 +0200
Subject: [PATCH] Split the cuda error file. (#2003)

---
 candle-core/src/cuda_backend/error.rs | 62 +++++++++++++++++++++++
 candle-core/src/cuda_backend/mod.rs   | 72 +++------------------------
 2 files changed, 68 insertions(+), 66 deletions(-)
 create mode 100644 candle-core/src/cuda_backend/error.rs

diff --git a/candle-core/src/cuda_backend/error.rs b/candle-core/src/cuda_backend/error.rs
new file mode 100644
index 00000000..bd6f8ac6
--- /dev/null
+++ b/candle-core/src/cuda_backend/error.rs
@@ -0,0 +1,62 @@
+use crate::{DType, Layout};
+
+/// cudarc related errors
+#[derive(thiserror::Error, Debug)]
+pub enum CudaError {
+    #[error(transparent)]
+    Cuda(#[from] cudarc::driver::DriverError),
+
+    #[error(transparent)]
+    Compiler(#[from] cudarc::nvrtc::CompileError),
+
+    #[error(transparent)]
+    Cublas(#[from] cudarc::cublas::result::CublasError),
+
+    #[error(transparent)]
+    Curand(#[from] cudarc::curand::result::CurandError),
+
+    #[error("missing kernel '{module_name}'")]
+    MissingKernel { module_name: String },
+
+    #[error("unsupported dtype {dtype:?} for {op}")]
+    UnsupportedDtype { dtype: DType, op: &'static str },
+
+    #[error("internal error '{0}'")]
+    InternalError(&'static str),
+
+    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
+    MatMulNonContiguous {
+        lhs_stride: Layout,
+        rhs_stride: Layout,
+        mnk: (usize, usize, usize),
+    },
+
+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedDType {
+        msg: &'static str,
+        expected: DType,
+        got: DType,
+    },
+
+    #[error("{cuda} when loading {module_name}")]
+    Load {
+        cuda: cudarc::driver::DriverError,
+        module_name: String,
+    },
+}
+
+impl From<CudaError> for crate::Error {
+    fn from(val: CudaError) -> Self {
+        crate::Error::Cuda(Box::new(val)).bt()
+    }
+}
+
+pub trait WrapErr<O> {
+    fn w(self) -> std::result::Result<O, crate::Error>;
+}
+
+impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
+    fn w(self) -> std::result::Result<O, crate::Error> {
+        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
+    }
+}
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 6a9e73f8..6fecf7c7 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -9,13 +9,14 @@ use cudarc::driver::{
 };
 use half::{bf16, f16};
 
-mod device;
-pub use device::{CudaDevice, DeviceId};
-mod utils;
-pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
-
 #[cfg(feature = "cudnn")]
 pub mod cudnn;
+mod device;
+mod error;
+mod utils;
+pub use device::{CudaDevice, DeviceId};
+pub use error::{CudaError, WrapErr};
+pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
 
 enum SlicePtrOrNull<T> {
     Ptr(CudaSlice<T>),
@@ -42,67 +43,6 @@ impl SlicePtrOrNull<usize> {
     }
 }
 
-/// cudarc related errors
-#[derive(thiserror::Error, Debug)]
-pub enum CudaError {
-    #[error(transparent)]
-    Cuda(#[from] cudarc::driver::DriverError),
-
-    #[error(transparent)]
-    Compiler(#[from] cudarc::nvrtc::CompileError),
-
-    #[error(transparent)]
-    Cublas(#[from] cudarc::cublas::result::CublasError),
-
-    #[error(transparent)]
-    Curand(#[from] cudarc::curand::result::CurandError),
-
-    #[error("missing kernel '{module_name}'")]
-    MissingKernel { module_name: String },
-
-    #[error("unsupported dtype {dtype:?} for {op}")]
-    UnsupportedDtype { dtype: DType, op: &'static str },
-
-    #[error("internal error '{0}'")]
-    InternalError(&'static str),
-
-    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
-    MatMulNonContiguous {
-        lhs_stride: Layout,
-        rhs_stride: Layout,
-        mnk: (usize, usize, usize),
-    },
-
-    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
-    UnexpectedDType {
-        msg: &'static str,
-        expected: DType,
-        got: DType,
-    },
-
-    #[error("{cuda} when loading {module_name}")]
-    Load {
-        cuda: cudarc::driver::DriverError,
-        module_name: String,
-    },
-}
-
-impl From<CudaError> for crate::Error {
-    fn from(val: CudaError) -> Self {
-        crate::Error::Cuda(Box::new(val)).bt()
-    }
-}
-
-pub trait WrapErr<O> {
-    fn w(self) -> std::result::Result<O, crate::Error>;
-}
-
-impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
-    fn w(self) -> std::result::Result<O, crate::Error> {
-        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
-    }
-}
-
 #[derive(Debug)]
 pub enum CudaStorageSlice {
     U8(CudaSlice<u8>),
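
Note on the design (illustrative, not part of the patch): the WrapErr trait that moves into
error.rs gives call sites a short .w() adapter that lifts any error convertible into CudaError
up to the crate-level error. The sketch below reproduces that shape with hypothetical stand-in
types (DemoDriverError, DemoCudaError, DemoError) so it compiles on its own, without candle or
cudarc; the names and the simplified error enums are assumptions made for this example only.

// Stand-in for cudarc::driver::DriverError in this sketch.
#[derive(Debug)]
struct DemoDriverError(&'static str);

// Stand-in for the backend-level CudaError; only the variant needed here.
#[derive(Debug)]
enum DemoCudaError {
    Cuda(DemoDriverError),
}

impl From<DemoDriverError> for DemoCudaError {
    fn from(e: DemoDriverError) -> Self {
        DemoCudaError::Cuda(e)
    }
}

// Stand-in for the crate-level error that boxes the backend error.
#[derive(Debug)]
enum DemoError {
    Cuda(Box<DemoCudaError>),
}

// Same shape as the WrapErr trait in the patch: a blanket impl over Result so any
// driver/compiler/cublas/curand error can be lifted with a single .w() call.
trait WrapErr<O> {
    fn w(self) -> Result<O, DemoError>;
}

impl<O, E: Into<DemoCudaError>> WrapErr<O> for Result<O, E> {
    fn w(self) -> Result<O, DemoError> {
        self.map_err(|e| DemoError::Cuda(Box::new(e.into())))
    }
}

fn main() {
    // A fallible "driver" call gets a terse conversion at the call site, mirroring
    // the .w()? calls used throughout the cuda backend.
    let r: Result<u32, DemoDriverError> = Err(DemoDriverError("out of memory"));
    let lifted: Result<u32, DemoError> = r.w();
    println!("{lifted:?}");
}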