use crate::{DType, Layout}; /// cudarc related errors #[derive(thiserror::Error, Debug)] pub enum CudaError { #[error(transparent)] Cuda(#[from] cudarc::driver::DriverError), #[error(transparent)] Compiler(#[from] cudarc::nvrtc::CompileError), #[error(transparent)] Cublas(#[from] cudarc::cublas::result::CublasError), #[error(transparent)] Curand(#[from] cudarc::curand::result::CurandError), #[error("missing kernel '{module_name}'")] MissingKernel { module_name: String }, #[error("unsupported dtype {dtype:?} for {op}")] UnsupportedDtype { dtype: DType, op: &'static str }, #[error("internal error '{0}'")] InternalError(&'static str), #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] MatMulNonContiguous { lhs_stride: Layout, rhs_stride: Layout, mnk: (usize, usize, usize), }, #[error("{msg}, expected: {expected:?}, got: {got:?}")] UnexpectedDType { msg: &'static str, expected: DType, got: DType, }, #[error("{cuda} when loading {module_name}")] Load { cuda: cudarc::driver::DriverError, module_name: String, }, } impl From for crate::Error { fn from(val: CudaError) -> Self { crate::Error::Cuda(Box::new(val)).bt() } } pub trait WrapErr { fn w(self) -> std::result::Result; } impl> WrapErr for std::result::Result { fn w(self) -> std::result::Result { self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt()) } }