diff --git a/README.md b/README.md
index 08c68e9e..ca1f0194 100644
--- a/README.md
+++ b/README.md
@@ -101,3 +101,8 @@ can try adding the following at the top of your binary:
 ```
 extern crate intel_mkl_src;
 ```
+
+### How to know where an error comes from
+
+You can set `RUST_BACKTRACE=1` to get a backtrace whenever a candle error is
+generated.
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index e6e2e7a2..28b1f5b0 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -56,7 +56,8 @@ trait Map2 {
                 lhs: v1.dtype(),
                 rhs: v2.dtype(),
                 op: Self::OP,
-            }),
+            }
+            .bt()),
         }
     }
 }
@@ -168,11 +169,12 @@ impl<'a> Map1 for Embedding<'a> {
         for index in self.ids_l.strided_index() {
             let index = self.ids[index].try_into()?;
             if index >= self.vocab_size {
-                return Err(Error::InvalidIndex {
+                Err(Error::InvalidIndex {
                     index,
                     vocab_size: self.vocab_size,
                     op: "take",
-                });
+                }
+                .bt())?
             } else {
                 let hidden_size = self.hidden_size;
                 values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
@@ -273,6 +275,7 @@ impl MatMul {
             bmnk: self.0,
             msg,
         }))
+        .bt()
     }
 }
 
@@ -483,7 +486,7 @@ impl Map2 for MatMul {
                     }
                 }
             }
-            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul"))?,
+            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
         }
         Ok(dst)
     }
@@ -748,8 +751,8 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| elu(v, alpha));
                 Ok(Self::F64(data))
             }
-            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu")),
-            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu")),
+            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
+            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
         }
     }
 
@@ -814,7 +817,8 @@ impl BackendStorage for CpuStorage {
                     lhs: self.dtype(),
                     rhs: rhs.dtype(),
                     op: B::NAME,
-                })
+                }
+                .bt())
             }
         }
     }
@@ -833,7 +837,8 @@ impl BackendStorage for CpuStorage {
                     lhs: self.dtype(),
                     rhs: dst.dtype(),
                     op: "copy_strided",
-                });
+                }
+                .bt());
             }
         }
         Ok(())
@@ -923,7 +928,7 @@ impl BackendDevice for CpuDevice {
         let mut rng = rand::thread_rng();
         match dtype {
            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
            }
            DType::F32 => {
                let mut data = Vec::new();
@@ -953,7 +958,7 @@ impl BackendDevice for CpuDevice {
         let mut rng = rand::thread_rng();
         match dtype {
            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
            }
            DType::F32 => {
                let mut data = Vec::new();
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 73bc9e34..74a3cf30 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -58,7 +58,7 @@ pub enum CudaError {
 
 impl From<CudaError> for crate::Error {
     fn from(val: CudaError) -> Self {
-        crate::Error::Cuda(Box::new(val))
+        crate::Error::Cuda(Box::new(val)).bt()
     }
 }
 
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 3a3f59f5..59802c04 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -92,7 +92,8 @@ macro_rules! with_dtype {
                         expected: DType::$dtype,
                         got: s.dtype(),
                         msg: "unexpected dtype",
-                    }),
+                    }
+                    .bt()),
                 }
             }
 
@@ -103,7 +104,8 @@ macro_rules! with_dtype {
                         expected: DType::$dtype,
                         got: s.dtype(),
                         msg: "unexpected dtype",
-                    }),
+                    }
+                    .bt()),
                 }
             }
         }
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index d2648e66..e354b239 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -170,6 +170,12 @@ pub enum Error {
 
     #[error(transparent)]
     Wrapped(Box<dyn std::error::Error + Send + Sync>),
+
+    #[error("{inner}\n{backtrace}")]
+    WithBacktrace {
+        inner: Box<Self>,
+        backtrace: Box<std::backtrace::Backtrace>,
+    },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
@@ -178,4 +184,16 @@ impl Error {
     pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
         Self::Wrapped(Box::new(err))
     }
+
+    pub fn bt(self) -> Self {
+        let backtrace = std::backtrace::Backtrace::capture();
+        match backtrace.status() {
+            std::backtrace::BacktraceStatus::Disabled
+            | std::backtrace::BacktraceStatus::Unsupported => self,
+            _ => Self::WithBacktrace {
+                inner: Box::new(self),
+                backtrace: Box::new(backtrace),
+            },
+        }
+    }
 }
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index 79d40cfc..d92864aa 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -67,7 +67,8 @@ impl Layout {
                 shape: self.shape().clone(),
                 dim: dim as i32,
                 op: "narrow",
-            })?
+            }
+            .bt())?
         }
         if start + len > dims[dim] {
             Err(Error::NarrowInvalidArgs {
@@ -76,7 +77,8 @@ impl Layout {
                 start,
                 len,
                 msg: "start + len > dim_len",
-            })?
+            }
+            .bt())?
         }
         let mut dims = dims.to_vec();
         dims[dim] = len;
@@ -90,11 +92,12 @@ impl Layout {
     pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
         let rank = self.shape.rank();
         if rank <= dim1 || rank <= dim2 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: usize::max(dim1, dim2),
                 got: rank,
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         let mut stride = self.stride().to_vec();
         let mut dims = self.shape().dims().to_vec();
@@ -110,10 +113,11 @@ impl Layout {
     pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
         let shape = shape.into();
         if shape.rank() < self.shape().rank() {
-            Err(Error::BroadcastIncompatibleShapes {
+            return Err(Error::BroadcastIncompatibleShapes {
                 src_shape: self.shape().clone(),
-                dst_shape: shape.clone(),
-            })?
+                dst_shape: shape,
+            }
+            .bt());
         }
         let added_dims = shape.rank() - self.shape().rank();
         let mut stride = vec![0; added_dims];
@@ -127,7 +131,8 @@ impl Layout {
                 return Err(Error::BroadcastIncompatibleShapes {
                     src_shape: self.shape().clone(),
                     dst_shape: shape,
-                });
+                }
+                .bt());
             } else {
                 0
             };
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index a267c068..d3f8db01 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -79,7 +79,8 @@ macro_rules! extract_dims {
                     expected: $cnt,
                     got: self.0.len(),
                     shape: self.clone(),
-                })
+                }
+                .bt())
             } else {
                 Ok($dims(&self.0))
             }
@@ -196,7 +197,8 @@ impl Dim for usize {
                 shape: shape.clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         } else {
             Ok(dim)
         }
@@ -209,7 +211,8 @@ impl Dim for usize {
                 shape: shape.clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         } else {
             Ok(dim)
         }
@@ -233,6 +236,7 @@ impl D {
             dim,
             op,
         }
+        .bt()
     }
 }
 
@@ -267,14 +271,16 @@ pub trait Dims: Sized {
                 shape: shape.clone(),
                 dims: dims.clone(),
                 op,
-            })?
+            }
+            .bt())?
         }
         if dim >= shape.rank() {
             Err(Error::DimOutOfRange {
                 shape: shape.clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         }
     }
     Ok(dims)
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 5f92172d..1531b212 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -38,7 +38,7 @@ impl Storage {
         let lhs = self.device().location();
         let rhs = rhs.device().location();
         if lhs != rhs {
-            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
+            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
         } else {
             Ok(())
         }
@@ -48,7 +48,7 @@ impl Storage {
         let lhs = self.dtype();
         let rhs = rhs.dtype();
         if lhs != rhs {
-            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op })
+            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
         } else {
             Ok(())
         }
@@ -153,7 +153,8 @@ impl Storage {
                     lhs: lhs.device().location(),
                     rhs: rhs.device().location(),
                     op: B::NAME,
-                })
+                }
+                .bt())
             }
         }
     }
@@ -180,7 +181,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "conv1d",
-            }),
+            }
+            .bt()),
         }
     }
 
@@ -208,7 +210,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "where",
-            }),
+            }
+            .bt()),
         }
     }
 
@@ -227,7 +230,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "embedding",
-            }),
+            }
+            .bt()),
         }
     }
 
@@ -253,7 +257,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "matmul",
-            }),
+            }
+            .bt()),
         }
     }
 
@@ -271,7 +276,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "copy",
-            }),
+            }
+            .bt()),
         }
     }
 }
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index c8353a70..c048790c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -281,7 +281,7 @@ impl Tensor {
         let n: usize = shape.elem_count();
         let buffer_size: usize = array.shape()?.elem_count();
         if buffer_size != n {
-            return Err(Error::ShapeMismatch { buffer_size, shape });
+            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage(array)?;
         Ok(from_storage(storage, shape, None, is_variable))
@@ -336,7 +336,7 @@ impl Tensor {
         let shape = shape.into();
         let buffer_size = data.len();
         if buffer_size != shape.elem_count() {
-            return Err(Error::ShapeMismatch { buffer_size, shape });
+            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage_owned(data)?;
         Ok(from_storage(storage, shape, None, is_variable))
@@ -398,7 +398,8 @@ impl Tensor {
                     lhs: self.shape().clone(),
                     rhs: rhs.shape().clone(),
                     op,
-                })?
+                }
+                .bt())?
             }
         }
         Ok(Shape::from(bcast_dims))
@@ -412,7 +413,8 @@ impl Tensor {
                 lhs: lhs.clone(),
                 rhs: rhs.clone(),
                 op,
-            })
+            }
+            .bt())
         } else {
             Ok(lhs)
         }
@@ -450,11 +452,12 @@ impl Tensor {
     /// dimensions, an error is returned instead.
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
         if self.rank() != 0 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: 0,
                 got: self.rank(),
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
@@ -508,7 +511,8 @@ impl Tensor {
                 shape: self.shape().clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         } else {
             Ok(())
         }
@@ -526,7 +530,8 @@ impl Tensor {
                 start,
                 len,
                 msg: "start + len > dim_len",
-            })?
+            }
+            .bt())?
         }
         if start == 0 && dims[dim] == len {
             Ok(self.clone())
@@ -666,7 +671,8 @@ impl Tensor {
                 padding,
                 stride,
                 msg: "input rank is not 2 or 3",
-            })?,
+            }
+            .bt())?,
         };
         if c_in != c_in_k {
             Err(Error::Conv1dInvalidArgs {
@@ -675,7 +681,8 @@ impl Tensor {
                 padding,
                 stride,
                 msg: "the number of in-channels on the input doesn't match the kernel size",
-            })?
+            }
+            .bt())?
         }
         let params = crate::conv::ParamsConv1D {
             b_size,
@@ -722,7 +729,8 @@ impl Tensor {
                 lhs: self.shape().clone(),
                 rhs: rhs.shape().clone(),
                 op: "matmul",
-            })?
+            }
+            .bt())?
         }
 
         let m = a_dims[dim - 2];
@@ -738,7 +746,8 @@ impl Tensor {
                 lhs: self.shape().clone(),
                 rhs: rhs.shape().clone(),
                 op: "matmul",
-            })?
+            }
+            .bt())?
         }
 
         let storage = self.storage().matmul(
@@ -801,13 +810,14 @@ impl Tensor {
     /// ```
     pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
         if !rhs.is_contiguous() {
-            return Err(Error::RequiresContiguous { op: "embedding" });
+            Err(Error::RequiresContiguous { op: "embedding" }.bt())?
         } else if rhs.rank() != 2 || ids.rank() != 1 {
-            return Err(Error::ShapeMismatchBinaryOp {
+            Err(Error::ShapeMismatchBinaryOp {
                 lhs: ids.shape().clone(),
                 rhs: rhs.shape().clone(),
                 op: "embedding",
-            });
+            }
+            .bt())?
         }
         let ids_shape = ids.shape();
         let seq_len = ids_shape.r1()?;
@@ -831,11 +841,12 @@ impl Tensor {
     /// Returns the data contained in a 1D tensor as a vector of scalar values.
     pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
         if self.rank() != 1 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: 1,
                 got: self.rank(),
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         match &*self.storage() {
             Storage::Cpu(cpu_storage) => {
@@ -1064,11 +1075,12 @@ impl Tensor {
     pub fn t(&self) -> Result<Tensor> {
         let rank = self.rank();
         if rank < 2 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: 2,
                 got: rank,
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         self.transpose(rank - 2, rank - 1)
     }
@@ -1278,7 +1290,8 @@ impl Tensor {
                 lhs: self.shape().clone(),
                 rhs: shape,
                 op: "reshape",
-            });
+            }
+            .bt());
         }
         let op = if self.track_op() {
             Some(Op::Reshape(self.clone()))
@@ -1370,7 +1383,7 @@ impl Tensor {
     /// ```
     pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
-            return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
+            Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
         }
         let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
         let args = args
@@ -1399,7 +1412,7 @@ impl Tensor {
     /// ```
     pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
-            return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
         }
         let arg0 = args[0].as_ref();
         if args.len() == 1 {
@@ -1425,7 +1438,7 @@ impl Tensor {
 
     fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
         if args.is_empty() {
-            return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
         }
         let arg0 = args[0].as_ref();
         if args.len() == 1 {
@@ -1442,19 +1455,21 @@ impl Tensor {
             let arg = arg.as_ref();
             if arg.dtype() != dtype {
                 // TODO: Improve the error message.
-                return Err(Error::DTypeMismatchBinaryOp {
+                Err(Error::DTypeMismatchBinaryOp {
                     lhs: dtype,
                     rhs: arg.dtype(),
                     op: "cat",
-                });
+                }
+                .bt())?
             }
             if arg.device().location() != device.location() {
                 // TODO: Improve the error message.
-                return Err(Error::DeviceMismatchBinaryOp {
+                Err(Error::DeviceMismatchBinaryOp {
                     lhs: device.location(),
                     rhs: arg.device().location(),
                     op: "cat",
-                });
+                }
+                .bt())?
             }
             let mut mismatch = arg.rank() != rank;
             for (dim_idx, (v1, v2)) in arg0
@@ -1474,12 +1489,13 @@ impl Tensor {
                 }
             }
             if mismatch {
-                return Err(Error::ShapeMismatchCat {
+                Err(Error::ShapeMismatchCat {
                     dim: 0, // TODO: not the appropriate error message
                     first_shape: arg0.shape().clone(),
                     n: arg_idx + 1,
                     nth_shape: arg.shape().clone(),
-                });
+                }
+                .bt())?
             }
             let next_offset = offsets.last().unwrap() + arg.elem_count();
             offsets.push(next_offset);
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index 0ae16c64..e26f1420 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -9,6 +9,12 @@ use crate::{DType, Device, Error, Result, Shape, Tensor};
 #[derive(Clone, Debug)]
 pub struct Var(Tensor);
 
+impl std::fmt::Display for Var {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        std::fmt::Display::fmt(&self.0, f)
+    }
+}
+
 impl std::ops::Deref for Var {
     type Target = Tensor;
 
@@ -90,12 +96,12 @@ impl Var {
     pub fn set(&self, src: &Tensor) -> Result<()> {
         if self.same_storage(src) {
             let msg = "cannot set a variable to a tensor that is derived from its value";
-            Err(Error::CannotSetVar { msg })?
+            Err(Error::CannotSetVar { msg }.bt())?
         }
         let (mut dst, layout) = self.storage_mut_and_layout();
         if !layout.is_contiguous() {
             let msg = "cannot set a non-contiguous variable";
-            Err(Error::CannotSetVar { msg })?
+            Err(Error::CannotSetVar { msg }.bt())?
         }
         let (src, src_l) = src.storage_and_layout();
         if layout.shape() != src_l.shape() {
@@ -103,7 +109,8 @@ impl Var {
                 lhs: layout.shape().clone(),
                 rhs: src_l.shape().clone(),
                 op: "set",
-            })?
+            }
+            .bt())?
         }
         src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
         Ok(())
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index aa2ec401..a6cb53e5 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -146,19 +146,28 @@ impl<'a> VarBuilder<'a> {
             Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
             Tensors::TensorMap(ts) => ts
                 .get(&path)
-                .ok_or_else(|| Error::CannotFindTensor {
-                    path: path.to_string(),
+                .ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: path.to_string(),
+                    }
+                    .bt()
                 })?
                 .clone(),
-            Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| Error::CannotFindTensor {
-                path: path.to_string(),
+            Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
+                Error::CannotFindTensor {
+                    path: path.to_string(),
+                }
+                .bt()
             })?,
             Tensors::SafeTensorWithRouting {
                 routing,
                 safetensors,
             } => {
-                let index = routing.get(&path).ok_or_else(|| Error::CannotFindTensor {
-                    path: path.to_string(),
+                let index = routing.get(&path).ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: path.to_string(),
+                    }
+                    .bt()
                 })?;
                 safetensors[*index]
                     .tensor(&path, &data.device)?
@@ -170,7 +179,8 @@ impl<'a> VarBuilder<'a> {
                 msg: format!("shape mismatch for {path}"),
                 expected: s,
                 got: tensor.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         Ok(tensor)
     }
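To illustrate the README note from a user's point of view, here is a minimal, hypothetical sketch (not part of the patch): it assumes the crate is imported as `candle_core` and only uses `Tensor::zeros` and `Tensor::matmul`. With `RUST_BACKTRACE=1` set, the printed error now ends with the captured backtrace; without it, the output is unchanged.

```rust
// Hypothetical example, not part of the patch: trigger a shape-mismatch error
// and print it. The `WithBacktrace` wrapping added in error.rs makes the
// Display output include a backtrace when RUST_BACKTRACE=1 is set.
use candle_core::{DType, Device, Tensor};

fn main() {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu).unwrap();
    let b = Tensor::zeros((5, 6), DType::F32, &Device::Cpu).unwrap();
    match a.matmul(&b) {
        Ok(_) => println!("unexpected success"),
        // Prints the shape-mismatch error, followed by the backtrace when
        // RUST_BACKTRACE=1 is set in the environment.
        Err(err) => eprintln!("{err}"),
    }
}
```

Running it as `RUST_BACKTRACE=1 cargo run` should point at the call site that produced the error.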
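The `Error::bt` helper added in `candle-core/src/error.rs` builds on `std::backtrace::Backtrace::capture`, which only records frames when `RUST_BACKTRACE` (or `RUST_LIB_BACKTRACE`) asks for them; when the status is `Disabled` or `Unsupported`, the error is returned unwrapped so the common path stays cheap. The following standalone sketch mirrors that pattern with a hypothetical `MyError` type and no candle dependency:

```rust
// Standalone sketch of the backtrace-wrapping pattern; `MyError` is a
// hypothetical stand-in for candle's `Error` enum.
use std::backtrace::{Backtrace, BacktraceStatus};

#[derive(Debug)]
enum MyError {
    Msg(String),
    WithBacktrace {
        inner: Box<MyError>,
        backtrace: Box<Backtrace>,
    },
}

impl MyError {
    fn bt(self) -> Self {
        let backtrace = Backtrace::capture();
        match backtrace.status() {
            // Nothing was captured (RUST_BACKTRACE unset or the platform is
            // unsupported): return the error unchanged.
            BacktraceStatus::Disabled | BacktraceStatus::Unsupported => self,
            _ => MyError::WithBacktrace {
                inner: Box::new(self),
                backtrace: Box::new(backtrace),
            },
        }
    }
}

fn main() {
    let err = MyError::Msg("boom".to_string()).bt();
    // Debug-prints the plain error unless RUST_BACKTRACE=1 is set.
    println!("{err:?}");
}
```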