Use bail rather than wrapping a string where possible. (#249)

* Use bail rather than wrapping a string where possible.

* Revert the cuda default bit.
This commit is contained in:
Laurent Mazare
2023-07-26 15:42:46 +01:00
committed by GitHub
parent f052ba76cb
commit 1235aa2536
3 changed files with 13 additions and 8 deletions

View File

@ -174,6 +174,7 @@ pub enum Error {
#[error("unsupported safetensor dtype {0:?}")] #[error("unsupported safetensor dtype {0:?}")]
UnsupportedSafeTensorDtype(safetensors::Dtype), UnsupportedSafeTensorDtype(safetensors::Dtype),
/// Arbitrary errors wrapping.
#[error(transparent)] #[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>), Wrapped(Box<dyn std::error::Error + Send + Sync>),
@ -182,6 +183,10 @@ pub enum Error {
inner: Box<Self>, inner: Box<Self>,
backtrace: Box<std::backtrace::Backtrace>, backtrace: Box<std::backtrace::Backtrace>,
}, },
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@ -207,12 +212,12 @@ impl Error {
#[macro_export] #[macro_export]
macro_rules! bail { macro_rules! bail {
($msg:literal $(,)?) => { ($msg:literal $(,)?) => {
return Err($crate::Error::Wrapped(format!($msg).into()).bt()) return Err($crate::Error::Msg(format!($msg).into()).bt())
}; };
($err:expr $(,)?) => { ($err:expr $(,)?) => {
return Err($crate::Error::Wrapped(format!($err).into()).bt()) return Err($crate::Error::Msg(format!($err).into()).bt())
}; };
($fmt:expr, $($arg:tt)*) => { ($fmt:expr, $($arg:tt)*) => {
return Err($crate::Error::Wrapped(format!($fmt, $($arg)*).into()).bt()) return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
}; };
} }

View File

@ -14,7 +14,7 @@ use clap::Parser;
use candle::backend::BackendStorage; use candle::backend::BackendStorage;
use candle::cpu_backend; use candle::cpu_backend;
use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor}; use candle::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
@ -37,7 +37,7 @@ impl CustomOp1 for LayerNorm {
let (dim1, dim2) = layout.shape().dims2()?; let (dim1, dim2) = layout.shape().dims2()?;
let slice = storage.as_slice::<f32>()?; let slice = storage.as_slice::<f32>()?;
let src = match layout.contiguous_offsets() { let src = match layout.contiguous_offsets() {
None => Err(Error::Wrapped("input has to be contiguous".into()))?, None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => &slice[o1..o2], Some((o1, o2)) => &slice[o1..o2],
}; };
let mut dst = Vec::with_capacity(dim1 * dim2); let mut dst = Vec::with_capacity(dim1 * dim2);
@ -65,7 +65,7 @@ impl CustomOp1 for LayerNorm {
let dev = storage.device().clone(); let dev = storage.device().clone();
let slice = storage.as_cuda_slice::<f32>()?; let slice = storage.as_cuda_slice::<f32>()?;
let slice = match layout.contiguous_offsets() { let slice = match layout.contiguous_offsets() {
None => Err(Error::Wrapped("input has to be contiguous".into()))?, None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => slice.slice(o1..o2), Some((o1, o2)) => slice.slice(o1..o2),
}; };
let elem_count = layout.shape().elem_count(); let elem_count = layout.shape().elem_count();

View File

@ -3,7 +3,7 @@ mod ffi;
use candle::backend::BackendStorage; use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr; use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr; use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor}; use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use half::f16; use half::f16;
pub struct FlashHdim32Sm80 { pub struct FlashHdim32Sm80 {
@ -29,7 +29,7 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
_: &CpuStorage, _: &CpuStorage,
_: &Layout, _: &Layout,
) -> Result<(CpuStorage, Shape)> { ) -> Result<(CpuStorage, Shape)> {
Err(Error::Wrapped("no cpu support for flash-attn".into())) candle::bail!("no cpu support for flash-attn")
} }
fn cuda_fwd( fn cuda_fwd(