Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Add backtrace information to errors where relevant. (#166)

* Add backtrace information to errors where relevant.
* More backtrace information.
* Add to the FAQ.
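The change itself is mostly mechanical: the `Error` enum gains a `WithBacktrace { inner, backtrace }` variant plus a `bt()` helper (see the `pub enum Error` and `impl Error` hunks further down), and each error construction site calls `.bt()` before propagating. Below is a condensed, self-contained sketch of that pattern, not code from this commit; the `DimOutOfRange`-style variant, the `narrow` function, and the hand-written `Display` impl are simplified stand-ins for the real candle types.

```rust
use std::backtrace::{Backtrace, BacktraceStatus};

// Stand-in error type: one "real" variant plus the backtrace wrapper.
#[derive(Debug)]
enum Error {
    DimOutOfRange { dim: usize, rank: usize },
    WithBacktrace { inner: Box<Error>, backtrace: Box<Backtrace> },
}

impl Error {
    // Attach a backtrace, but only when backtraces are actually enabled
    // (e.g. RUST_BACKTRACE=1), so the default configuration stays cheap.
    fn bt(self) -> Self {
        let backtrace = Backtrace::capture();
        match backtrace.status() {
            BacktraceStatus::Disabled | BacktraceStatus::Unsupported => self,
            _ => Error::WithBacktrace {
                inner: Box::new(self),
                backtrace: Box::new(backtrace),
            },
        }
    }
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::DimOutOfRange { dim, rank } => {
                write!(f, "dimension {dim} is out of range for rank {rank}")
            }
            // Same shape as the `#[error("{inner}\n{backtrace}")]` attribute in the diff.
            Error::WithBacktrace { inner, backtrace } => write!(f, "{inner}\n{backtrace}"),
        }
    }
}

// The call-site pattern used throughout the diff: build the error, call
// `.bt()` on it, then propagate with `?`.
fn narrow(rank: usize, dim: usize) -> Result<(), Error> {
    if dim >= rank {
        Err(Error::DimOutOfRange { dim, rank }.bt())?
    }
    Ok(())
}

fn main() {
    // Run with `RUST_BACKTRACE=1` to see the captured frames after the message.
    if let Err(e) = narrow(2, 5) {
        eprintln!("{e}");
    }
}
```

Because `bt()` checks the capture status first, the wrapping only happens when backtraces are enabled; otherwise the error is returned unchanged.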
@@ -101,3 +101,8 @@ can try adding the following at the top of your binary:
 ```
 extern crate intel_mkl_src;
 ```
+
+### How to know where an error comes from.
+
+You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
+error is generated.
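The FAQ entry works because `std::backtrace::Backtrace::capture()` (used by the new `bt()` helper in this diff) only records frames when the `RUST_BACKTRACE` or `RUST_LIB_BACKTRACE` environment variables enable it; otherwise it returns a disabled placeholder. A minimal, candle-independent check:

```rust
use std::backtrace::{Backtrace, BacktraceStatus};

fn main() {
    // With RUST_BACKTRACE unset, capture() is a cheap no-op and status() is
    // Disabled; with RUST_BACKTRACE=1 the frames are recorded and printed.
    let bt = Backtrace::capture();
    match bt.status() {
        BacktraceStatus::Captured => println!("captured:\n{bt}"),
        _ => println!("backtraces are disabled; re-run with RUST_BACKTRACE=1"),
    }
}
```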
@@ -56,7 +56,8 @@ trait Map2 {
                 lhs: v1.dtype(),
                 rhs: v2.dtype(),
                 op: Self::OP,
-            }),
+            }
+            .bt()),
         }
     }
 }

@@ -168,11 +169,12 @@ impl<'a> Map1 for Embedding<'a> {
         for index in self.ids_l.strided_index() {
             let index = self.ids[index].try_into()?;
             if index >= self.vocab_size {
-                return Err(Error::InvalidIndex {
+                Err(Error::InvalidIndex {
                     index,
                     vocab_size: self.vocab_size,
                     op: "take",
-                });
+                }
+                .bt())?
             } else {
                 let hidden_size = self.hidden_size;
                 values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);

@@ -273,6 +275,7 @@ impl MatMul {
             bmnk: self.0,
             msg,
         }))
+        .bt()
     }
 }
 
@@ -483,7 +486,7 @@ impl Map2 for MatMul {
                     }
                 }
             }
-            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul"))?,
+            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
         }
         Ok(dst)
     }

@@ -748,8 +751,8 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| elu(v, alpha));
                 Ok(Self::F64(data))
             }
-            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu")),
-            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu")),
+            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
+            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
         }
     }
 
@@ -814,7 +817,8 @@ impl BackendStorage for CpuStorage {
                     lhs: self.dtype(),
                     rhs: rhs.dtype(),
                     op: B::NAME,
-                })
+                }
+                .bt())
             }
         }
     }

@@ -833,7 +837,8 @@ impl BackendStorage for CpuStorage {
                     lhs: self.dtype(),
                     rhs: dst.dtype(),
                     op: "copy_strided",
-                });
+                }
+                .bt());
             }
         }
         Ok(())

@@ -923,7 +928,7 @@ impl BackendDevice for CpuDevice {
         let mut rng = rand::thread_rng();
         match dtype {
             DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
             }
             DType::F32 => {
                 let mut data = Vec::new();

@@ -953,7 +958,7 @@ impl BackendDevice for CpuDevice {
         let mut rng = rand::thread_rng();
         match dtype {
             DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
             }
             DType::F32 => {
                 let mut data = Vec::new();
@@ -58,7 +58,7 @@ pub enum CudaError {
 
 impl From<CudaError> for crate::Error {
     fn from(val: CudaError) -> Self {
-        crate::Error::Cuda(Box::new(val))
+        crate::Error::Cuda(Box::new(val)).bt()
     }
 }
 
@@ -92,7 +92,8 @@ macro_rules! with_dtype {
                 expected: DType::$dtype,
                 got: s.dtype(),
                 msg: "unexpected dtype",
-            }),
+            }
+            .bt()),
         }
     }
 
@@ -103,7 +104,8 @@ macro_rules! with_dtype {
                 expected: DType::$dtype,
                 got: s.dtype(),
                 msg: "unexpected dtype",
-            }),
+            }
+            .bt()),
         }
     }
 }
@@ -170,6 +170,12 @@ pub enum Error {
 
     #[error(transparent)]
     Wrapped(Box<dyn std::error::Error + Send + Sync>),
+
+    #[error("{inner}\n{backtrace}")]
+    WithBacktrace {
+        inner: Box<Self>,
+        backtrace: Box<std::backtrace::Backtrace>,
+    },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;

@@ -178,4 +184,16 @@ impl Error {
     pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
         Self::Wrapped(Box::new(err))
     }
+
+    pub fn bt(self) -> Self {
+        let backtrace = std::backtrace::Backtrace::capture();
+        match backtrace.status() {
+            std::backtrace::BacktraceStatus::Disabled
+            | std::backtrace::BacktraceStatus::Unsupported => self,
+            _ => Self::WithBacktrace {
+                inner: Box::new(self),
+                backtrace: Box::new(backtrace),
+            },
+        }
+    }
 }
@@ -67,7 +67,8 @@ impl Layout {
                 shape: self.shape().clone(),
                 dim: dim as i32,
                 op: "narrow",
-            })?
+            }
+            .bt())?
         }
         if start + len > dims[dim] {
             Err(Error::NarrowInvalidArgs {

@@ -76,7 +77,8 @@ impl Layout {
                 start,
                 len,
                 msg: "start + len > dim_len",
-            })?
+            }
+            .bt())?
         }
         let mut dims = dims.to_vec();
         dims[dim] = len;

@@ -90,11 +92,12 @@
     pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
         let rank = self.shape.rank();
         if rank <= dim1 || rank <= dim2 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: usize::max(dim1, dim2),
                 got: rank,
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         let mut stride = self.stride().to_vec();
         let mut dims = self.shape().dims().to_vec();

@@ -110,10 +113,11 @@
     pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
         let shape = shape.into();
         if shape.rank() < self.shape().rank() {
-            Err(Error::BroadcastIncompatibleShapes {
+            return Err(Error::BroadcastIncompatibleShapes {
                 src_shape: self.shape().clone(),
-                dst_shape: shape.clone(),
-            })?
+                dst_shape: shape,
+            }
+            .bt());
         }
         let added_dims = shape.rank() - self.shape().rank();
         let mut stride = vec![0; added_dims];

@@ -127,7 +131,8 @@
                 return Err(Error::BroadcastIncompatibleShapes {
                     src_shape: self.shape().clone(),
                     dst_shape: shape,
-                });
+                }
+                .bt());
             } else {
                 0
             };
@@ -79,7 +79,8 @@ macro_rules! extract_dims {
                 expected: $cnt,
                 got: self.0.len(),
                 shape: self.clone(),
-            })
+            }
+            .bt())
         } else {
             Ok($dims(&self.0))
         }

@@ -196,7 +197,8 @@ impl Dim for usize {
                 shape: shape.clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         } else {
             Ok(dim)
         }

@@ -209,7 +211,8 @@ impl Dim for usize {
                 shape: shape.clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         } else {
             Ok(dim)
         }

@@ -233,6 +236,7 @@ impl D {
             dim,
             op,
         }
+        .bt()
     }
 }
 
@@ -267,14 +271,16 @@ pub trait Dims: Sized {
                 shape: shape.clone(),
                 dims: dims.clone(),
                 op,
-            })?
+            }
+            .bt())?
         }
         if dim >= shape.rank() {
             Err(Error::DimOutOfRange {
                 shape: shape.clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         }
     }
     Ok(dims)
@@ -38,7 +38,7 @@ impl Storage {
         let lhs = self.device().location();
         let rhs = rhs.device().location();
         if lhs != rhs {
-            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
+            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
         } else {
             Ok(())
         }

@@ -48,7 +48,7 @@ impl Storage {
         let lhs = self.dtype();
         let rhs = rhs.dtype();
         if lhs != rhs {
-            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op })
+            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
         } else {
             Ok(())
         }

@@ -153,7 +153,8 @@ impl Storage {
                     lhs: lhs.device().location(),
                     rhs: rhs.device().location(),
                     op: B::NAME,
-                })
+                }
+                .bt())
             }
         }
     }

@@ -180,7 +181,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "conv1d",
-            }),
+            }
+            .bt()),
         }
     }
 
@@ -208,7 +210,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "where",
-            }),
+            }
+            .bt()),
         }
     }
 
@@ -227,7 +230,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "embedding",
-            }),
+            }
+            .bt()),
         }
     }
 
@@ -253,7 +257,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "matmul",
-            }),
+            }
+            .bt()),
         }
     }
 
@@ -271,7 +276,8 @@ impl Storage {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
                 op: "copy",
-            }),
+            }
+            .bt()),
         }
     }
 }
@@ -281,7 +281,7 @@ impl Tensor {
         let n: usize = shape.elem_count();
         let buffer_size: usize = array.shape()?.elem_count();
         if buffer_size != n {
-            return Err(Error::ShapeMismatch { buffer_size, shape });
+            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage(array)?;
         Ok(from_storage(storage, shape, None, is_variable))

@@ -336,7 +336,7 @@ impl Tensor {
         let shape = shape.into();
         let buffer_size = data.len();
         if buffer_size != shape.elem_count() {
-            return Err(Error::ShapeMismatch { buffer_size, shape });
+            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage_owned(data)?;
         Ok(from_storage(storage, shape, None, is_variable))

@@ -398,7 +398,8 @@ impl Tensor {
                     lhs: self.shape().clone(),
                     rhs: rhs.shape().clone(),
                     op,
-                })?
+                }
+                .bt())?
             }
         }
         Ok(Shape::from(bcast_dims))

@@ -412,7 +413,8 @@ impl Tensor {
                 lhs: lhs.clone(),
                 rhs: rhs.clone(),
                 op,
-            })
+            }
+            .bt())
         } else {
             Ok(lhs)
         }

@@ -450,11 +452,12 @@ impl Tensor {
     /// dimensions, an error is returned instead.
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
         if self.rank() != 0 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: 0,
                 got: self.rank(),
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;

@@ -508,7 +511,8 @@ impl Tensor {
                 shape: self.shape().clone(),
                 dim: dim as i32,
                 op,
-            })?
+            }
+            .bt())?
         } else {
             Ok(())
         }

@@ -526,7 +530,8 @@ impl Tensor {
                 start,
                 len,
                 msg: "start + len > dim_len",
-            })?
+            }
+            .bt())?
         }
         if start == 0 && dims[dim] == len {
             Ok(self.clone())
@@ -666,7 +671,8 @@ impl Tensor {
                 padding,
                 stride,
                 msg: "input rank is not 2 or 3",
-            })?,
+            }
+            .bt())?,
         };
         if c_in != c_in_k {
             Err(Error::Conv1dInvalidArgs {

@@ -675,7 +681,8 @@ impl Tensor {
                 padding,
                 stride,
                 msg: "the number of in-channels on the input doesn't match the kernel size",
-            })?
+            }
+            .bt())?
         }
         let params = crate::conv::ParamsConv1D {
             b_size,

@@ -722,7 +729,8 @@ impl Tensor {
                 lhs: self.shape().clone(),
                 rhs: rhs.shape().clone(),
                 op: "matmul",
-            })?
+            }
+            .bt())?
         }
 
         let m = a_dims[dim - 2];

@@ -738,7 +746,8 @@ impl Tensor {
                 lhs: self.shape().clone(),
                 rhs: rhs.shape().clone(),
                 op: "matmul",
-            })?
+            }
+            .bt())?
         }
 
         let storage = self.storage().matmul(

@@ -801,13 +810,14 @@ impl Tensor {
     /// ```
     pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
         if !rhs.is_contiguous() {
-            return Err(Error::RequiresContiguous { op: "embedding" });
+            Err(Error::RequiresContiguous { op: "embedding" }.bt())?
         } else if rhs.rank() != 2 || ids.rank() != 1 {
-            return Err(Error::ShapeMismatchBinaryOp {
+            Err(Error::ShapeMismatchBinaryOp {
                 lhs: ids.shape().clone(),
                 rhs: rhs.shape().clone(),
                 op: "embedding",
-            });
+            }
+            .bt())?
         }
         let ids_shape = ids.shape();
         let seq_len = ids_shape.r1()?;

@@ -831,11 +841,12 @@ impl Tensor {
     /// Returns the data contained in a 1D tensor as a vector of scalar values.
     pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
         if self.rank() != 1 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: 1,
                 got: self.rank(),
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         match &*self.storage() {
             Storage::Cpu(cpu_storage) => {

@@ -1064,11 +1075,12 @@ impl Tensor {
     pub fn t(&self) -> Result<Tensor> {
         let rank = self.rank();
         if rank < 2 {
-            return Err(Error::UnexpectedNumberOfDims {
+            Err(Error::UnexpectedNumberOfDims {
                 expected: 2,
                 got: rank,
                 shape: self.shape().clone(),
-            });
+            }
+            .bt())?
         }
         self.transpose(rank - 2, rank - 1)
     }
@@ -1278,7 +1290,8 @@ impl Tensor {
                 lhs: self.shape().clone(),
                 rhs: shape,
                 op: "reshape",
-            });
+            }
+            .bt());
         }
         let op = if self.track_op() {
             Some(Op::Reshape(self.clone()))

@@ -1370,7 +1383,7 @@ impl Tensor {
     /// ```
     pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
-            return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
+            Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
         }
         let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
         let args = args

@@ -1399,7 +1412,7 @@ impl Tensor {
     /// ```
     pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
-            return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
         }
         let arg0 = args[0].as_ref();
         if args.len() == 1 {

@@ -1425,7 +1438,7 @@ impl Tensor {
 
     fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
         if args.is_empty() {
-            return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
         }
         let arg0 = args[0].as_ref();
         if args.len() == 1 {

@@ -1442,19 +1455,21 @@ impl Tensor {
             let arg = arg.as_ref();
             if arg.dtype() != dtype {
                 // TODO: Improve the error message.
-                return Err(Error::DTypeMismatchBinaryOp {
+                Err(Error::DTypeMismatchBinaryOp {
                     lhs: dtype,
                     rhs: arg.dtype(),
                     op: "cat",
-                });
+                }
+                .bt())?
             }
             if arg.device().location() != device.location() {
                 // TODO: Improve the error message.
-                return Err(Error::DeviceMismatchBinaryOp {
+                Err(Error::DeviceMismatchBinaryOp {
                     lhs: device.location(),
                     rhs: arg.device().location(),
                     op: "cat",
-                });
+                }
+                .bt())?
             }
             let mut mismatch = arg.rank() != rank;
             for (dim_idx, (v1, v2)) in arg0

@@ -1474,12 +1489,13 @@ impl Tensor {
                 }
             }
             if mismatch {
-                return Err(Error::ShapeMismatchCat {
+                Err(Error::ShapeMismatchCat {
                     dim: 0, // TODO: not the appropriate error message
                     first_shape: arg0.shape().clone(),
                     n: arg_idx + 1,
                     nth_shape: arg.shape().clone(),
-                });
+                }
+                .bt())?
             }
             let next_offset = offsets.last().unwrap() + arg.elem_count();
             offsets.push(next_offset);
@@ -9,6 +9,12 @@ use crate::{DType, Device, Error, Result, Shape, Tensor};
 #[derive(Clone, Debug)]
 pub struct Var(Tensor);
 
+impl std::fmt::Display for Var {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        std::fmt::Display::fmt(&self.0, f)
+    }
+}
+
 impl std::ops::Deref for Var {
     type Target = Tensor;
 

@@ -90,12 +96,12 @@ impl Var {
     pub fn set(&self, src: &Tensor) -> Result<()> {
         if self.same_storage(src) {
             let msg = "cannot set a variable to a tensor that is derived from its value";
-            Err(Error::CannotSetVar { msg })?
+            Err(Error::CannotSetVar { msg }.bt())?
         }
         let (mut dst, layout) = self.storage_mut_and_layout();
         if !layout.is_contiguous() {
             let msg = "cannot set a non-contiguous variable";
-            Err(Error::CannotSetVar { msg })?
+            Err(Error::CannotSetVar { msg }.bt())?
         }
         let (src, src_l) = src.storage_and_layout();
         if layout.shape() != src_l.shape() {

@@ -103,7 +109,8 @@ impl Var {
                 lhs: layout.shape().clone(),
                 rhs: src_l.shape().clone(),
                 op: "set",
-            })?
+            }
+            .bt())?
         }
         src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
         Ok(())
@@ -146,19 +146,28 @@ impl<'a> VarBuilder<'a> {
             Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
             Tensors::TensorMap(ts) => ts
                 .get(&path)
-                .ok_or_else(|| Error::CannotFindTensor {
-                    path: path.to_string(),
+                .ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: path.to_string(),
+                    }
+                    .bt()
                 })?
                 .clone(),
-            Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| Error::CannotFindTensor {
-                path: path.to_string(),
+            Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
+                Error::CannotFindTensor {
+                    path: path.to_string(),
+                }
+                .bt()
             })?,
             Tensors::SafeTensorWithRouting {
                 routing,
                 safetensors,
             } => {
-                let index = routing.get(&path).ok_or_else(|| Error::CannotFindTensor {
-                    path: path.to_string(),
+                let index = routing.get(&path).ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: path.to_string(),
+                    }
+                    .bt()
                 })?;
                 safetensors[*index]
                     .tensor(&path, &data.device)?

@@ -170,7 +179,8 @@ impl<'a> VarBuilder<'a> {
                 msg: format!("shape mismatch for {path}"),
                 expected: s,
                 got: tensor.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         Ok(tensor)
     }