Add backtrace information to errors where relevant. (#166)

* Add backtrace information to errors where relevant.

* More backtrace information.

* Add to the FAQ.
Laurent Mazare
2023-07-14 09:31:25 +01:00
committed by GitHub
parent a2f72edc0d
commit d88b6cdca9
11 changed files with 153 additions and 73 deletions

View File

@@ -101,3 +101,8 @@ can try adding the following at the top of your binary:
 ```
 extern crate intel_mkl_src;
 ```
+
+### How to know where an error comes from.
+
+You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
+error is generated.
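For instance, the following minimal sketch deliberately triggers a shape mismatch (this example is illustrative and assumes the crate is imported as `candle`, as it was named at the time of this commit):

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((5, 7), DType::F32, &Device::Cpu)?;
    // The inner dimensions do not match, so this returns an error; with
    // RUST_BACKTRACE=1 the printed error also carries a backtrace pointing
    // at this call site.
    let _c = a.matmul(&b)?;
    Ok(())
}
```

Running it as `RUST_BACKTRACE=1 cargo run` prints the backtrace alongside the error message.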

View File

@@ -56,7 +56,8 @@ trait Map2 {
                 lhs: v1.dtype(),
                 rhs: v2.dtype(),
                 op: Self::OP,
-            }),
+            }
+            .bt()),
         }
     }
 }
@@ -168,11 +169,12 @@ impl<'a> Map1 for Embedding<'a> {
         for index in self.ids_l.strided_index() {
             let index = self.ids[index].try_into()?;
             if index >= self.vocab_size {
-                return Err(Error::InvalidIndex {
+                Err(Error::InvalidIndex {
                     index,
                     vocab_size: self.vocab_size,
                     op: "take",
-                });
+                }
+                .bt())?
             } else {
                 let hidden_size = self.hidden_size;
                 values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
@@ -273,6 +275,7 @@ impl MatMul {
             bmnk: self.0,
             msg,
         }))
+        .bt()
     }
 }
@@ -483,7 +486,7 @@ impl Map2 for MatMul {
                     }
                 }
             }
-            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul"))?,
+            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
         }
         Ok(dst)
     }
@@ -748,8 +751,8 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| elu(v, alpha));
                 Ok(Self::F64(data))
             }
-            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu")),
-            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu")),
+            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
+            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
         }
     }
@@ -814,7 +817,8 @@ impl BackendStorage for CpuStorage {
                     lhs: self.dtype(),
                     rhs: rhs.dtype(),
                     op: B::NAME,
-                })
+                }
+                .bt())
             }
         }
     }
@@ -833,7 +837,8 @@ impl BackendStorage for CpuStorage {
                     lhs: self.dtype(),
                     rhs: dst.dtype(),
                     op: "copy_strided",
-                });
+                }
+                .bt());
             }
         }
         Ok(())
@@ -923,7 +928,7 @@ impl BackendDevice for CpuDevice {
         let mut rng = rand::thread_rng();
         match dtype {
             DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
            }
             DType::F32 => {
                 let mut data = Vec::new();
@@ -953,7 +958,7 @@ impl BackendDevice for CpuDevice {
         let mut rng = rand::thread_rng();
         match dtype {
             DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
             }
             DType::F32 => {
                 let mut data = Vec::new();

View File

@@ -58,7 +58,7 @@ pub enum CudaError {
 impl From<CudaError> for crate::Error {
     fn from(val: CudaError) -> Self {
-        crate::Error::Cuda(Box::new(val))
+        crate::Error::Cuda(Box::new(val)).bt()
     }
 }
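Attaching the backtrace in the `From` impl means every `?` that converts a `CudaError` into a `crate::Error` records the propagation site without touching each call. A sketch of how this plays out; the helper function and the `InternalError` variant used here are illustrative assumptions, not part of this diff:

```rust
// Hypothetical helper that fails with a CudaError.
fn try_cuda_op() -> Result<(), CudaError> {
    Err(CudaError::InternalError("sketch: some driver-level failure"))
}

fn forward() -> crate::Result<()> {
    // `?` goes through `From<CudaError> for crate::Error`, so the
    // backtrace is captured right here, at the conversion point.
    try_cuda_op()?;
    Ok(())
}
```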

View File

@@ -92,7 +92,8 @@ macro_rules! with_dtype {
                     expected: DType::$dtype,
                     got: s.dtype(),
                     msg: "unexpected dtype",
-                }),
+                }
+                .bt()),
             }
         }
@@ -103,7 +104,8 @@ macro_rules! with_dtype {
                     expected: DType::$dtype,
                     got: s.dtype(),
                     msg: "unexpected dtype",
-                }),
+                }
+                .bt()),
             }
         }
     }

View File

@@ -170,6 +170,12 @@ pub enum Error {
     #[error(transparent)]
     Wrapped(Box<dyn std::error::Error + Send + Sync>),
+
+    #[error("{inner}\n{backtrace}")]
+    WithBacktrace {
+        inner: Box<Self>,
+        backtrace: Box<std::backtrace::Backtrace>,
+    },
 }

 pub type Result<T> = std::result::Result<T, Error>;
@@ -178,4 +184,16 @@ impl Error {
     pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
         Self::Wrapped(Box::new(err))
     }
+
+    pub fn bt(self) -> Self {
+        let backtrace = std::backtrace::Backtrace::capture();
+        match backtrace.status() {
+            std::backtrace::BacktraceStatus::Disabled
+            | std::backtrace::BacktraceStatus::Unsupported => self,
+            _ => Self::WithBacktrace {
+                inner: Box::new(self),
+                backtrace: Box::new(backtrace),
+            },
+        }
+    }
 }
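`Backtrace::capture()` is cheap when backtraces are disabled: it only walks the stack when `RUST_BACKTRACE` or `RUST_LIB_BACKTRACE` is set, which is why `bt()` can be sprinkled on every error path without a systematic cost. A standalone sketch of the same check, using only the standard library:

```rust
use std::backtrace::{Backtrace, BacktraceStatus};

fn main() {
    // Same logic as `Error::bt`: keep the error as-is when capture is
    // disabled or unsupported, wrap it with the backtrace otherwise.
    let backtrace = Backtrace::capture();
    match backtrace.status() {
        BacktraceStatus::Disabled | BacktraceStatus::Unsupported => {
            println!("no backtrace captured; set RUST_BACKTRACE=1")
        }
        _ => println!("{backtrace}"),
    }
}
```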

View File

@@ -67,7 +67,8 @@ impl Layout {
                 shape: self.shape().clone(),
                 dim: dim as i32,
                 op: "narrow",
-            })?
+            }
+            .bt())?
         }
         if start + len > dims[dim] {
             Err(Error::NarrowInvalidArgs {
@@ -76,7 +77,8 @@ impl Layout {
                 start,
                 len,
                 msg: "start + len > dim_len",
-            })?
+            }
+            .bt())?
         }
         let mut dims = dims.to_vec();
         dims[dim] = len;
@@ -90,11 +92,12 @@ impl Layout {
     pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
         let rank = self.shape.rank();
         if rank <= dim1 || rank <= dim2 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: usize::max(dim1, dim2),
                 got: rank,
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         let mut stride = self.stride().to_vec();
         let mut dims = self.shape().dims().to_vec();
@@ -110,10 +113,11 @@ impl Layout {
     pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
         let shape = shape.into();
         if shape.rank() < self.shape().rank() {
-            Err(Error::BroadcastIncompatibleShapes {
+            return Err(Error::BroadcastIncompatibleShapes {
                 src_shape: self.shape().clone(),
-                dst_shape: shape.clone(),
-            })?
+                dst_shape: shape,
+            }
+            .bt());
         }
         let added_dims = shape.rank() - self.shape().rank();
         let mut stride = vec![0; added_dims];
@@ -127,7 +131,8 @@ impl Layout {
                 return Err(Error::BroadcastIncompatibleShapes {
                     src_shape: self.shape().clone(),
                     dst_shape: shape,
-                });
+                }
+                .bt());
             } else {
                 0
             };

View File

@@ -79,7 +79,8 @@ macro_rules! extract_dims {
                     expected: $cnt,
                     got: self.0.len(),
                     shape: self.clone(),
-                })
+                }
+                .bt())
             } else {
                 Ok($dims(&self.0))
             }
@@ -196,7 +197,8 @@ impl Dim for usize {
                 shape: shape.clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         } else {
             Ok(dim)
         }
@@ -209,7 +211,8 @@ impl Dim for usize {
                 shape: shape.clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         } else {
             Ok(dim)
         }
@@ -233,6 +236,7 @@ impl D {
             dim,
             op,
         }
+        .bt()
     }
 }
@@ -267,14 +271,16 @@ pub trait Dims: Sized {
                     shape: shape.clone(),
                     dims: dims.clone(),
                     op,
-                })?
+                }
+                .bt())?
             }
             if dim >= shape.rank() {
                 Err(Error::DimOutOfRange {
                     shape: shape.clone(),
                     dim: dim as i32,
                     op,
-                })?
+                }
+                .bt())?
             }
         }
         Ok(dims)

View File

@@ -38,7 +38,7 @@ impl Storage {
         let lhs = self.device().location();
         let rhs = rhs.device().location();
         if lhs != rhs {
-            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
+            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
         } else {
             Ok(())
         }
@@ -48,7 +48,7 @@ impl Storage {
         let lhs = self.dtype();
         let rhs = rhs.dtype();
         if lhs != rhs {
-            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op })
+            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
         } else {
             Ok(())
         }
@@ -153,7 +153,8 @@ impl Storage {
                     lhs: lhs.device().location(),
                     rhs: rhs.device().location(),
                     op: B::NAME,
-                })
+                }
+                .bt())
             }
         }
     }
@@ -180,7 +181,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "conv1d",
-            }),
+            }
+            .bt()),
         }
     }
@@ -208,7 +210,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "where",
-            }),
+            }
+            .bt()),
         }
     }
@@ -227,7 +230,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "embedding",
-            }),
+            }
+            .bt()),
         }
     }
@@ -253,7 +257,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "matmul",
-            }),
+            }
+            .bt()),
         }
     }
@@ -271,7 +276,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "copy",
-            }),
+            }
+            .bt()),
         }
     }
 }

View File

@@ -281,7 +281,7 @@ impl Tensor {
         let n: usize = shape.elem_count();
         let buffer_size: usize = array.shape()?.elem_count();
         if buffer_size != n {
-            return Err(Error::ShapeMismatch { buffer_size, shape });
+            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage(array)?;
         Ok(from_storage(storage, shape, None, is_variable))
@@ -336,7 +336,7 @@ impl Tensor {
         let shape = shape.into();
         let buffer_size = data.len();
         if buffer_size != shape.elem_count() {
-            return Err(Error::ShapeMismatch { buffer_size, shape });
+            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage_owned(data)?;
         Ok(from_storage(storage, shape, None, is_variable))
@@ -398,7 +398,8 @@ impl Tensor {
                     lhs: self.shape().clone(),
                     rhs: rhs.shape().clone(),
                     op,
-                })?
+                }
+                .bt())?
             }
         }
         Ok(Shape::from(bcast_dims))
@@ -412,7 +413,8 @@ impl Tensor {
                 lhs: lhs.clone(),
                 rhs: rhs.clone(),
                 op,
-            })
+            }
+            .bt())
         } else {
             Ok(lhs)
         }
@@ -450,11 +452,12 @@ impl Tensor {
     /// dimensions, an error is returned instead.
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
         if self.rank() != 0 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: 0,
                 got: self.rank(),
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
@@ -508,7 +511,8 @@ impl Tensor {
                 shape: self.shape().clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         } else {
             Ok(())
         }
@@ -526,7 +530,8 @@ impl Tensor {
                 start,
                 len,
                 msg: "start + len > dim_len",
-            })?
+            }
+            .bt())?
         }
         if start == 0 && dims[dim] == len {
             Ok(self.clone())
@@ -666,7 +671,8 @@ impl Tensor {
                 padding,
                 stride,
                 msg: "input rank is not 2 or 3",
-            })?,
+            }
+            .bt())?,
         };
         if c_in != c_in_k {
             Err(Error::Conv1dInvalidArgs {
@@ -675,7 +681,8 @@ impl Tensor {
                 padding,
                 stride,
                 msg: "the number of in-channels on the input doesn't match the kernel size",
-            })?
+            }
+            .bt())?
         }
         let params = crate::conv::ParamsConv1D {
             b_size,
@@ -722,7 +729,8 @@ impl Tensor {
                 lhs: self.shape().clone(),
                 rhs: rhs.shape().clone(),
                 op: "matmul",
-            })?
+            }
+            .bt())?
         }
         let m = a_dims[dim - 2];
@@ -738,7 +746,8 @@ impl Tensor {
                 lhs: self.shape().clone(),
                 rhs: rhs.shape().clone(),
                 op: "matmul",
-            })?
+            }
+            .bt())?
         }
         let storage = self.storage().matmul(
@@ -801,13 +810,14 @@ impl Tensor {
     /// ```
     pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
         if !rhs.is_contiguous() {
-            return Err(Error::RequiresContiguous { op: "embedding" });
+            Err(Error::RequiresContiguous { op: "embedding" }.bt())?
         } else if rhs.rank() != 2 || ids.rank() != 1 {
-            return Err(Error::ShapeMismatchBinaryOp {
+            Err(Error::ShapeMismatchBinaryOp {
                 lhs: ids.shape().clone(),
                 rhs: rhs.shape().clone(),
                 op: "embedding",
-            });
+            }
+            .bt())?
         }
         let ids_shape = ids.shape();
         let seq_len = ids_shape.r1()?;
@@ -831,11 +841,12 @@ impl Tensor {
     /// Returns the data contained in a 1D tensor as a vector of scalar values.
     pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
         if self.rank() != 1 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: 1,
                 got: self.rank(),
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         match &*self.storage() {
             Storage::Cpu(cpu_storage) => {
@@ -1064,11 +1075,12 @@ impl Tensor {
     pub fn t(&self) -> Result<Tensor> {
         let rank = self.rank();
         if rank < 2 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: 2,
                 got: rank,
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         self.transpose(rank - 2, rank - 1)
     }
@@ -1278,7 +1290,8 @@ impl Tensor {
                 lhs: self.shape().clone(),
                 rhs: shape,
                 op: "reshape",
-            });
+            }
+            .bt());
         }
         let op = if self.track_op() {
             Some(Op::Reshape(self.clone()))
@@ -1370,7 +1383,7 @@ impl Tensor {
     /// ```
     pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
-            return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
+            Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
         }
         let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
         let args = args
@@ -1399,7 +1412,7 @@ impl Tensor {
     /// ```
     pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
-            return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
         }
         let arg0 = args[0].as_ref();
         if args.len() == 1 {
@@ -1425,7 +1438,7 @@ impl Tensor {
     fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
         if args.is_empty() {
-            return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
         }
         let arg0 = args[0].as_ref();
         if args.len() == 1 {
@@ -1442,19 +1455,21 @@ impl Tensor {
             let arg = arg.as_ref();
             if arg.dtype() != dtype {
                 // TODO: Improve the error message.
-                return Err(Error::DTypeMismatchBinaryOp {
+                Err(Error::DTypeMismatchBinaryOp {
                     lhs: dtype,
                     rhs: arg.dtype(),
                     op: "cat",
-                });
+                }
+                .bt())?
             }
             if arg.device().location() != device.location() {
                 // TODO: Improve the error message.
-                return Err(Error::DeviceMismatchBinaryOp {
+                Err(Error::DeviceMismatchBinaryOp {
                     lhs: device.location(),
                     rhs: arg.device().location(),
                     op: "cat",
-                });
+                }
+                .bt())?
             }
             let mut mismatch = arg.rank() != rank;
             for (dim_idx, (v1, v2)) in arg0
@@ -1474,12 +1489,13 @@ impl Tensor {
                 }
             }
             if mismatch {
-                return Err(Error::ShapeMismatchCat {
+                Err(Error::ShapeMismatchCat {
                     dim: 0, // TODO: not the appropriate error message
                     first_shape: arg0.shape().clone(),
                     n: arg_idx + 1,
                     nth_shape: arg.shape().clone(),
-                });
+                }
+                .bt())?
             }
             let next_offset = offsets.last().unwrap() + arg.elem_count();
             offsets.push(next_offset);
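Note the recurring rewrite throughout this file: `return Err(e);` becomes `Err(e.bt())?`. Since `bt()` consumes the error by value, it is easiest to chain on the freshly built error, and `?` then propagates it exactly like an explicit `return`. A minimal sketch of the idiom; the function and values here are made up for illustration:

```rust
use candle::{Error, Shape};

// Hypothetical guard mirroring the rewritten checks above.
fn expect_scalar(rank: usize, shape: &Shape) -> candle::Result<()> {
    if rank != 0 {
        // `Err(...)?` returns early like `return Err(...)`, while letting
        // the error value flow through `.bt()` and the `?` conversion.
        Err(Error::UnexpectedNumberOfDims {
            expected: 0,
            got: rank,
            shape: shape.clone(),
        }
        .bt())?
    }
    Ok(())
}
```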

View File

@@ -9,6 +9,12 @@ use crate::{DType, Device, Error, Result, Shape, Tensor};
 #[derive(Clone, Debug)]
 pub struct Var(Tensor);

+impl std::fmt::Display for Var {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        std::fmt::Display::fmt(&self.0, f)
+    }
+}
+
 impl std::ops::Deref for Var {
     type Target = Tensor;
@@ -90,12 +96,12 @@ impl Var {
     pub fn set(&self, src: &Tensor) -> Result<()> {
         if self.same_storage(src) {
             let msg = "cannot set a variable to a tensor that is derived from its value";
-            Err(Error::CannotSetVar { msg })?
+            Err(Error::CannotSetVar { msg }.bt())?
         }
         let (mut dst, layout) = self.storage_mut_and_layout();
         if !layout.is_contiguous() {
             let msg = "cannot set a non-contiguous variable";
-            Err(Error::CannotSetVar { msg })?
+            Err(Error::CannotSetVar { msg }.bt())?
         }
         let (src, src_l) = src.storage_and_layout();
         if layout.shape() != src_l.shape() {
@@ -103,7 +109,8 @@ impl Var {
                 lhs: layout.shape().clone(),
                 rhs: src_l.shape().clone(),
                 op: "set",
-            })?
+            }
+            .bt())?
         }
         src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
         Ok(())
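The new `Display` impl means a `Var` prints exactly like its underlying `Tensor`. A quick sketch, assuming `Var::zeros` mirroring `Tensor::zeros`:

```rust
use candle::{DType, Device, Var};

fn main() -> candle::Result<()> {
    let v = Var::zeros((2, 2), DType::F32, &Device::Cpu)?;
    // Delegates to the inner tensor's Display implementation.
    println!("{v}");
    Ok(())
}
```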

View File

@@ -146,19 +146,28 @@ impl<'a> VarBuilder<'a> {
             Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
             Tensors::TensorMap(ts) => ts
                 .get(&path)
-                .ok_or_else(|| Error::CannotFindTensor {
-                    path: path.to_string(),
+                .ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: path.to_string(),
+                    }
+                    .bt()
                 })?
                 .clone(),
-            Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| Error::CannotFindTensor {
-                path: path.to_string(),
+            Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
+                Error::CannotFindTensor {
+                    path: path.to_string(),
+                }
+                .bt()
             })?,
             Tensors::SafeTensorWithRouting {
                 routing,
                 safetensors,
             } => {
-                let index = routing.get(&path).ok_or_else(|| Error::CannotFindTensor {
-                    path: path.to_string(),
+                let index = routing.get(&path).ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: path.to_string(),
+                    }
+                    .bt()
                 })?;
                 safetensors[*index]
                     .tensor(&path, &data.device)?
@@ -170,7 +179,8 @@ impl<'a> VarBuilder<'a> {
                 msg: format!("shape mismatch for {path}"),
                 expected: s,
                 got: tensor.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         Ok(tensor)
     }