mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Cleanup the main crate error and add a couple dedicated ones (#142)
* Cosmetic cleanups to the error enum. * More error cleanup. * Proper error handling rather than panicing. * Add some conv1d dedicated error.
This commit is contained in:
@ -1,8 +1,17 @@
|
||||
use crate::{DType, DeviceLocation, Shape};
|
||||
use crate::{DType, DeviceLocation, Layout, Shape};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatMulUnexpectedStriding {
|
||||
pub lhs_l: Layout,
|
||||
pub rhs_l: Layout,
|
||||
pub bmnk: (usize, usize, usize, usize),
|
||||
pub msg: &'static str,
|
||||
}
|
||||
|
||||
/// Main library error type.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
// === DType Errors ===
|
||||
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||
UnexpectedDType {
|
||||
msg: &'static str,
|
||||
@ -10,6 +19,32 @@ pub enum Error {
|
||||
got: DType,
|
||||
},
|
||||
|
||||
#[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
DTypeMismatchBinaryOp {
|
||||
lhs: DType,
|
||||
rhs: DType,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
#[error("unsupported dtype {0:?} for op {1}")]
|
||||
UnsupportedDTypeForOp(DType, &'static str),
|
||||
|
||||
// === Dimension Index Errors ===
|
||||
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
|
||||
DimOutOfRange {
|
||||
shape: Shape,
|
||||
dim: i32,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
// === Shape Errors ===
|
||||
#[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
|
||||
UnexpectedNumberOfDims {
|
||||
expected: usize,
|
||||
got: usize,
|
||||
shape: Shape,
|
||||
},
|
||||
|
||||
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||
UnexpectedShape {
|
||||
msg: String,
|
||||
@ -17,40 +52,6 @@ pub enum Error {
|
||||
got: Shape,
|
||||
},
|
||||
|
||||
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
|
||||
DimOutOfRange {
|
||||
shape: Shape,
|
||||
dim: usize,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
#[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
|
||||
NarrowInvalidArgs {
|
||||
shape: Shape,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
len: usize,
|
||||
},
|
||||
|
||||
#[error("{op} only supports contiguous tensors")]
|
||||
RequiresContiguous { op: &'static str },
|
||||
|
||||
#[error("{op} expects at least one tensor")]
|
||||
OpRequiresAtLeastOneTensor { op: &'static str },
|
||||
|
||||
#[error("backward is not supported for {op}")]
|
||||
BackwardNotSupported { op: &'static str },
|
||||
|
||||
#[error("{op} invalid index {index} with vocab {vocab_size}")]
|
||||
InvalidIndex {
|
||||
op: &'static str,
|
||||
index: usize,
|
||||
vocab_size: usize,
|
||||
},
|
||||
|
||||
#[error("the candle crate has not been built with cuda support")]
|
||||
NotCompiledWithCudaSupport,
|
||||
|
||||
#[error(
|
||||
"Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
|
||||
)]
|
||||
@ -71,6 +72,7 @@ pub enum Error {
|
||||
nth_shape: Shape,
|
||||
},
|
||||
|
||||
// === Device Errors ===
|
||||
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
DeviceMismatchBinaryOp {
|
||||
lhs: DeviceLocation,
|
||||
@ -78,27 +80,56 @@ pub enum Error {
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
#[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
DTypeMismatchBinaryOp {
|
||||
lhs: DType,
|
||||
rhs: DType,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
#[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
|
||||
UnexpectedNumberOfDims {
|
||||
expected: usize,
|
||||
got: usize,
|
||||
// === Op Specific Errors ===
|
||||
#[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
|
||||
NarrowInvalidArgs {
|
||||
shape: Shape,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
len: usize,
|
||||
msg: &'static str,
|
||||
},
|
||||
|
||||
// TODO this is temporary when we support arbitrary matmul
|
||||
#[error("temporary error where matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")]
|
||||
UnexpectedStriding {
|
||||
lhs_stride: Vec<usize>,
|
||||
rhs_stride: Vec<usize>,
|
||||
#[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")]
|
||||
Conv1dInvalidArgs {
|
||||
inp_shape: Shape,
|
||||
k_shape: Shape,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
msg: &'static str,
|
||||
},
|
||||
|
||||
#[error("{op} invalid index {index} with vocab {vocab_size}")]
|
||||
InvalidIndex {
|
||||
op: &'static str,
|
||||
index: usize,
|
||||
vocab_size: usize,
|
||||
},
|
||||
|
||||
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
||||
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
|
||||
|
||||
// Box indirection to avoid large variant.
|
||||
#[error("{0:?}")]
|
||||
MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),
|
||||
|
||||
#[error("{op} only supports contiguous tensors")]
|
||||
RequiresContiguous { op: &'static str },
|
||||
|
||||
#[error("{op} expects at least one tensor")]
|
||||
OpRequiresAtLeastOneTensor { op: &'static str },
|
||||
|
||||
#[error("backward is not supported for {op}")]
|
||||
BackwardNotSupported { op: &'static str },
|
||||
|
||||
// === Other Errors ===
|
||||
#[error("the candle crate has not been built with cuda support")]
|
||||
NotCompiledWithCudaSupport,
|
||||
|
||||
#[error("cannot find tensor {path}")]
|
||||
CannotFindTensor { path: String },
|
||||
|
||||
// === Wrapped Errors ===
|
||||
#[error(transparent)]
|
||||
Cuda(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
@ -126,22 +157,6 @@ pub enum Error {
|
||||
|
||||
#[error("unsupported safetensor dtype {0:?}")]
|
||||
UnsupportedSafeTensorDtype(safetensors::Dtype),
|
||||
|
||||
#[error("unsupported dtype {0:?} for op {1}")]
|
||||
UnsupportedDTypeForOp(DType, &'static str),
|
||||
|
||||
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
||||
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
|
||||
|
||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||
MatMulNonContiguous {
|
||||
lhs_stride: Vec<usize>,
|
||||
rhs_stride: Vec<usize>,
|
||||
mnk: (usize, usize, usize),
|
||||
},
|
||||
|
||||
#[error("cannot find tensor {path}")]
|
||||
CannotFindTensor { path: String },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
Reference in New Issue
Block a user