diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 3f860fb0..f9e69122 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -174,6 +174,7 @@ pub enum Error {
     #[error("unsupported safetensor dtype {0:?}")]
     UnsupportedSafeTensorDtype(safetensors::Dtype),
 
+    /// Arbitrary errors wrapping.
     #[error(transparent)]
     Wrapped(Box<dyn std::error::Error + Send + Sync>),
 
@@ -182,6 +183,10 @@ pub enum Error {
         inner: Box<Self>,
         backtrace: Box<std::backtrace::Backtrace>,
     },
+
+    /// User generated error message, typically created via `bail!`.
+    #[error("{0}")]
+    Msg(String),
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
@@ -207,12 +212,12 @@ impl Error {
 #[macro_export]
 macro_rules! bail {
     ($msg:literal $(,)?) => {
-        return Err($crate::Error::Wrapped(format!($msg).into()).bt())
+        return Err($crate::Error::Msg(format!($msg).into()).bt())
     };
     ($err:expr $(,)?) => {
-        return Err($crate::Error::Wrapped(format!($err).into()).bt())
+        return Err($crate::Error::Msg(format!($err).into()).bt())
     };
     ($fmt:expr, $($arg:tt)*) => {
-        return Err($crate::Error::Wrapped(format!($fmt, $($arg)*).into()).bt())
+        return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
     };
 }
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
index 0de78b72..63bcd83a 100644
--- a/candle-examples/examples/custom-ops/main.rs
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -14,7 +14,7 @@ use clap::Parser;
 
 use candle::backend::BackendStorage;
 use candle::cpu_backend;
-use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -37,7 +37,7 @@ impl CustomOp1 for LayerNorm {
         let (dim1, dim2) = layout.shape().dims2()?;
         let slice = storage.as_slice::<f32>()?;
         let src = match layout.contiguous_offsets() {
-            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+            None => candle::bail!("input has to be contiguous"),
             Some((o1, o2)) => &slice[o1..o2],
         };
         let mut dst = Vec::with_capacity(dim1 * dim2);
@@ -65,7 +65,7 @@ impl CustomOp1 for LayerNorm {
         let dev = storage.device().clone();
         let slice = storage.as_cuda_slice::<f32>()?;
         let slice = match layout.contiguous_offsets() {
-            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+            None => candle::bail!("input has to be contiguous"),
             Some((o1, o2)) => slice.slice(o1..o2),
         };
         let elem_count = layout.shape().elem_count();
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index c2dec7d7..0123543b 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -3,7 +3,7 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, Layout, Result, Shape, Tensor};
 use half::f16;
 
 pub struct FlashHdim32Sm80 {
@@ -29,7 +29,7 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         _: &CpuStorage,
         _: &Layout,
     ) -> Result<(CpuStorage, Shape)> {
-        Err(Error::Wrapped("no cpu support for flash-attn".into()))
+        candle::bail!("no cpu support for flash-attn")
     }
 
     fn cuda_fwd(