mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Split the cuda error file. (#2003)
This commit is contained in:
62
candle-core/src/cuda_backend/error.rs
Normal file
62
candle-core/src/cuda_backend/error.rs
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
use crate::{DType, Layout};
|
||||||
|
|
||||||
|
/// cudarc related errors
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum CudaError {
|
||||||
|
#[error(transparent)]
|
||||||
|
Cuda(#[from] cudarc::driver::DriverError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Compiler(#[from] cudarc::nvrtc::CompileError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Cublas(#[from] cudarc::cublas::result::CublasError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Curand(#[from] cudarc::curand::result::CurandError),
|
||||||
|
|
||||||
|
#[error("missing kernel '{module_name}'")]
|
||||||
|
MissingKernel { module_name: String },
|
||||||
|
|
||||||
|
#[error("unsupported dtype {dtype:?} for {op}")]
|
||||||
|
UnsupportedDtype { dtype: DType, op: &'static str },
|
||||||
|
|
||||||
|
#[error("internal error '{0}'")]
|
||||||
|
InternalError(&'static str),
|
||||||
|
|
||||||
|
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||||
|
MatMulNonContiguous {
|
||||||
|
lhs_stride: Layout,
|
||||||
|
rhs_stride: Layout,
|
||||||
|
mnk: (usize, usize, usize),
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||||
|
UnexpectedDType {
|
||||||
|
msg: &'static str,
|
||||||
|
expected: DType,
|
||||||
|
got: DType,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("{cuda} when loading {module_name}")]
|
||||||
|
Load {
|
||||||
|
cuda: cudarc::driver::DriverError,
|
||||||
|
module_name: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<CudaError> for crate::Error {
|
||||||
|
fn from(val: CudaError) -> Self {
|
||||||
|
crate::Error::Cuda(Box::new(val)).bt()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait WrapErr<O> {
|
||||||
|
fn w(self) -> std::result::Result<O, crate::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
||||||
|
fn w(self) -> std::result::Result<O, crate::Error> {
|
||||||
|
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
|
||||||
|
}
|
||||||
|
}
|
@ -9,13 +9,14 @@ use cudarc::driver::{
|
|||||||
};
|
};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
|
|
||||||
mod device;
|
|
||||||
pub use device::{CudaDevice, DeviceId};
|
|
||||||
mod utils;
|
|
||||||
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
|
|
||||||
|
|
||||||
#[cfg(feature = "cudnn")]
|
#[cfg(feature = "cudnn")]
|
||||||
pub mod cudnn;
|
pub mod cudnn;
|
||||||
|
mod device;
|
||||||
|
mod error;
|
||||||
|
mod utils;
|
||||||
|
pub use device::{CudaDevice, DeviceId};
|
||||||
|
pub use error::{CudaError, WrapErr};
|
||||||
|
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
|
||||||
|
|
||||||
enum SlicePtrOrNull<T> {
|
enum SlicePtrOrNull<T> {
|
||||||
Ptr(CudaSlice<T>),
|
Ptr(CudaSlice<T>),
|
||||||
@ -42,67 +43,6 @@ impl SlicePtrOrNull<usize> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// cudarc related errors
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
|
||||||
pub enum CudaError {
|
|
||||||
#[error(transparent)]
|
|
||||||
Cuda(#[from] cudarc::driver::DriverError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Compiler(#[from] cudarc::nvrtc::CompileError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Cublas(#[from] cudarc::cublas::result::CublasError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Curand(#[from] cudarc::curand::result::CurandError),
|
|
||||||
|
|
||||||
#[error("missing kernel '{module_name}'")]
|
|
||||||
MissingKernel { module_name: String },
|
|
||||||
|
|
||||||
#[error("unsupported dtype {dtype:?} for {op}")]
|
|
||||||
UnsupportedDtype { dtype: DType, op: &'static str },
|
|
||||||
|
|
||||||
#[error("internal error '{0}'")]
|
|
||||||
InternalError(&'static str),
|
|
||||||
|
|
||||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
|
||||||
MatMulNonContiguous {
|
|
||||||
lhs_stride: Layout,
|
|
||||||
rhs_stride: Layout,
|
|
||||||
mnk: (usize, usize, usize),
|
|
||||||
},
|
|
||||||
|
|
||||||
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
|
||||||
UnexpectedDType {
|
|
||||||
msg: &'static str,
|
|
||||||
expected: DType,
|
|
||||||
got: DType,
|
|
||||||
},
|
|
||||||
|
|
||||||
#[error("{cuda} when loading {module_name}")]
|
|
||||||
Load {
|
|
||||||
cuda: cudarc::driver::DriverError,
|
|
||||||
module_name: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<CudaError> for crate::Error {
|
|
||||||
fn from(val: CudaError) -> Self {
|
|
||||||
crate::Error::Cuda(Box::new(val)).bt()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait WrapErr<O> {
|
|
||||||
fn w(self) -> std::result::Result<O, crate::Error>;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
|
||||||
fn w(self) -> std::result::Result<O, crate::Error> {
|
|
||||||
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum CudaStorageSlice {
|
pub enum CudaStorageSlice {
|
||||||
U8(CudaSlice<u8>),
|
U8(CudaSlice<u8>),
|
||||||
|
Reference in New Issue
Block a user