Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Use bail rather than wrapping a string where possible. (#249)
* Use bail rather than wrapping a string where possible.
* Revert the cuda default bit.
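
The change swaps ad-hoc `Err(Error::Wrapped("...".into()))?` constructions for the `bail!` macro, which formats the message, builds the new `Error::Msg` variant, and returns early with a backtrace attached when backtrace capture is enabled. A minimal before/after sketch of the call-site pattern (the `check_rank` helper below is hypothetical, not part of this commit):

    use candle::{Result, Tensor};

    // Hypothetical helper, only to contrast the two styles this commit touches.
    fn check_rank(t: &Tensor) -> Result<()> {
        // Before: wrap a plain string in the catch-all `Wrapped` variant by hand:
        //     return Err(candle::Error::Wrapped("expected a rank-2 tensor".into()).bt());
        // After: let `bail!` format the message and early-return an `Error::Msg`.
        if t.rank() != 2 {
            candle::bail!("expected a rank-2 tensor, got rank {}", t.rank());
        }
        Ok(())
    }
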
@@ -174,6 +174,7 @@ pub enum Error {
     #[error("unsupported safetensor dtype {0:?}")]
     UnsupportedSafeTensorDtype(safetensors::Dtype),
 
+    /// Arbitrary errors wrapping.
     #[error(transparent)]
     Wrapped(Box<dyn std::error::Error + Send + Sync>),
 
@@ -182,6 +183,10 @@ pub enum Error {
         inner: Box<Self>,
         backtrace: Box<std::backtrace::Backtrace>,
     },
+
+    /// User generated error message, typically created via `bail!`.
+    #[error("{0}")]
+    Msg(String),
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
@@ -207,12 +212,12 @@ impl Error {
 #[macro_export]
 macro_rules! bail {
     ($msg:literal $(,)?) => {
-        return Err($crate::Error::Wrapped(format!($msg).into()).bt())
+        return Err($crate::Error::Msg(format!($msg).into()).bt())
     };
     ($err:expr $(,)?) => {
-        return Err($crate::Error::Wrapped(format!($err).into()).bt())
+        return Err($crate::Error::Msg(format!($err).into()).bt())
     };
     ($fmt:expr, $($arg:tt)*) => {
-        return Err($crate::Error::Wrapped(format!($fmt, $($arg)*).into()).bt())
+        return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
     };
 }
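
For callers, the practical difference is that `bail!` now produces `Error::Msg(String)` instead of a boxed `Wrapped` error, so the text displays (and can be matched) as a plain string. A small usage sketch, assuming a hypothetical `parse_dtype` helper rather than anything in the candle API:

    use candle::{DType, Result};

    // Hypothetical helper, only to show the ergonomics of the updated `bail!`.
    fn parse_dtype(name: &str) -> Result<DType> {
        match name {
            "f16" => Ok(DType::F16),
            "f32" => Ok(DType::F32),
            // Early-returns Err(Error::Msg("unknown dtype: ...")); a backtrace is
            // attached via `.bt()` when backtrace capture is enabled.
            _ => candle::bail!("unknown dtype: {name}"),
        }
    }

    fn main() {
        if let Err(e) = parse_dtype("f64") {
            // The `Msg` variant displays the formatted message as-is.
            println!("{e}");
        }
    }
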
@@ -14,7 +14,7 @@ use clap::Parser;
 
 use candle::backend::BackendStorage;
 use candle::cpu_backend;
-use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -37,7 +37,7 @@ impl CustomOp1 for LayerNorm {
         let (dim1, dim2) = layout.shape().dims2()?;
         let slice = storage.as_slice::<f32>()?;
         let src = match layout.contiguous_offsets() {
-            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+            None => candle::bail!("input has to be contiguous"),
             Some((o1, o2)) => &slice[o1..o2],
         };
         let mut dst = Vec::with_capacity(dim1 * dim2);
@@ -65,7 +65,7 @@ impl CustomOp1 for LayerNorm {
         let dev = storage.device().clone();
         let slice = storage.as_cuda_slice::<f32>()?;
         let slice = match layout.contiguous_offsets() {
-            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+            None => candle::bail!("input has to be contiguous"),
             Some((o1, o2)) => slice.slice(o1..o2),
         };
         let elem_count = layout.shape().elem_count();
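
A side note on the two hunks above: the old arms needed the trailing `?` on an already-built `Err(...)` to convert the error and return early, while `bail!` expands to a `return`, whose `!` type unifies with whatever the `Some` arm yields. A standalone sketch of the same shape, using illustrative names only (not the example's actual signature):

    use candle::Result;

    // Illustrative only: mirrors the match used in the custom-op example above.
    fn contiguous_part(slice: &[f32], offsets: Option<(usize, usize)>) -> Result<&[f32]> {
        let src = match offsets {
            // `bail!` early-returns, so this arm has type `!` and unifies with `&[f32]`.
            None => candle::bail!("input has to be contiguous"),
            Some((o1, o2)) => &slice[o1..o2],
        };
        Ok(src)
    }
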
@@ -3,7 +3,7 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, Layout, Result, Shape, Tensor};
 use half::f16;
 
 pub struct FlashHdim32Sm80 {
@@ -29,7 +29,7 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         _: &CpuStorage,
         _: &Layout,
     ) -> Result<(CpuStorage, Shape)> {
-        Err(Error::Wrapped("no cpu support for flash-attn".into()))
+        candle::bail!("no cpu support for flash-attn")
     }
 
     fn cuda_fwd(