Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Cleanup the main crate error and add a couple dedicated ones (#142)
* Cosmetic cleanups to the error enum.
* More error cleanup.
* Proper error handling rather than panicking.
* Add a dedicated conv1d error.
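The core pattern of this change, in a minimal self-contained sketch: dedicated, data-carrying error variants returned through `Result` instead of `todo!()` panics, with the large matmul-striding payload boxed so it does not bloat the main `Error` enum. The sketch only mirrors the variants added in the diff below; it swaps the crate's `Layout` type for plain `Vec<usize>` strides, trims fields, and the `check_conv1d_rank` helper is purely illustrative rather than part of the candle API. It needs only the `thiserror` crate.

// Minimal sketch of the error pattern this commit adopts (illustrative, not the
// exact candle types): dedicated variants instead of panics, with the large
// matmul-striding payload boxed to keep `Error` itself small.
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
    pub lhs_stride: Vec<usize>, // the real code stores full `Layout`s here
    pub rhs_stride: Vec<usize>,
    pub bmnk: (usize, usize, usize, usize),
    pub msg: &'static str,
}

#[derive(thiserror::Error, Debug)]
pub enum Error {
    // Box indirection avoids one large enum variant inflating every `Result`.
    #[error("{0:?}")]
    MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),

    #[error("conv1d invalid args {msg}: pad: {padding}, stride: {stride}")]
    Conv1dInvalidArgs {
        padding: usize,
        stride: usize,
        msg: &'static str,
    },
}

pub type Result<T> = std::result::Result<T, Error>;

// Callers now get a structured error back instead of hitting a todo!() panic.
// Hypothetical helper, not part of the crate.
fn check_conv1d_rank(rank: usize, padding: usize, stride: usize) -> Result<()> {
    if rank != 2 && rank != 3 {
        return Err(Error::Conv1dInvalidArgs {
            padding,
            stride,
            msg: "input rank is not 2 or 3",
        });
    }
    Ok(())
}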
@@ -261,6 +261,17 @@ impl<'a> Map2 for Conv1D<'a> {
 
 struct MatMul((usize, usize, usize, usize));
 
+impl MatMul {
+    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
+        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
+            lhs_l: lhs_l.clone(),
+            rhs_l: rhs_l.clone(),
+            bmnk: self.0,
+            msg,
+        }))
+    }
+}
+
 impl Map2 for MatMul {
     const OP: &'static str = "mat_mul";
 
@@ -290,19 +301,13 @@ impl Map2 for MatMul {
             [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => m * k,
-            _ => Err(Error::UnexpectedStriding {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-            })?,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
         };
         let b_skip: usize = match rhs_stride[..rank - 2] {
             [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => n * k,
-            _ => Err(Error::UnexpectedStriding {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-            })?,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
         };
         let c_skip: usize = m * n;
 
@@ -369,19 +374,13 @@ impl Map2 for MatMul {
             [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => m * k,
-            _ => Err(Error::UnexpectedStriding {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-            })?,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
         };
         let b_skip: usize = match rhs_stride[..rank - 2] {
             [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => n * k,
-            _ => Err(Error::UnexpectedStriding {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-            })?,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
         };
         let c_skip: usize = m * n;
 
@@ -395,11 +394,7 @@ impl Map2 for MatMul {
         } else if rhs_m1 == k && rhs_m2 == 1 {
             (k as i32, b'T')
         } else {
-            Err(Error::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?
+            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
         };
         // The b tensor has dims batching, m, k (lhs)
         let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@@ -407,11 +402,7 @@ impl Map2 for MatMul {
         } else if lhs_m1 == m && lhs_m2 == 1 {
             (m as i32, b'T')
         } else {
-            Err(Error::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?
+            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
         };
 
         let mut dst = vec![T::zero(); b * m * n];
@@ -1,8 +1,17 @@
-use crate::{DType, DeviceLocation, Shape};
+use crate::{DType, DeviceLocation, Layout, Shape};
 
+#[derive(Debug, Clone)]
+pub struct MatMulUnexpectedStriding {
+    pub lhs_l: Layout,
+    pub rhs_l: Layout,
+    pub bmnk: (usize, usize, usize, usize),
+    pub msg: &'static str,
+}
+
 /// Main library error type.
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
+    // === DType Errors ===
     #[error("{msg}, expected: {expected:?}, got: {got:?}")]
     UnexpectedDType {
         msg: &'static str,
@@ -10,6 +19,32 @@ pub enum Error {
         got: DType,
     },
 
+    #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
+    DTypeMismatchBinaryOp {
+        lhs: DType,
+        rhs: DType,
+        op: &'static str,
+    },
+
+    #[error("unsupported dtype {0:?} for op {1}")]
+    UnsupportedDTypeForOp(DType, &'static str),
+
+    // === Dimension Index Errors ===
+    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
+    DimOutOfRange {
+        shape: Shape,
+        dim: i32,
+        op: &'static str,
+    },
+
+    // === Shape Errors ===
+    #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
+    UnexpectedNumberOfDims {
+        expected: usize,
+        got: usize,
+        shape: Shape,
+    },
+
     #[error("{msg}, expected: {expected:?}, got: {got:?}")]
     UnexpectedShape {
         msg: String,
@@ -17,40 +52,6 @@ pub enum Error {
         got: Shape,
     },
 
-    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
-    DimOutOfRange {
-        shape: Shape,
-        dim: usize,
-        op: &'static str,
-    },
-
-    #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
-    NarrowInvalidArgs {
-        shape: Shape,
-        dim: usize,
-        start: usize,
-        len: usize,
-    },
-
-    #[error("{op} only supports contiguous tensors")]
-    RequiresContiguous { op: &'static str },
-
-    #[error("{op} expects at least one tensor")]
-    OpRequiresAtLeastOneTensor { op: &'static str },
-
-    #[error("backward is not supported for {op}")]
-    BackwardNotSupported { op: &'static str },
-
-    #[error("{op} invalid index {index} with vocab {vocab_size}")]
-    InvalidIndex {
-        op: &'static str,
-        index: usize,
-        vocab_size: usize,
-    },
-
-    #[error("the candle crate has not been built with cuda support")]
-    NotCompiledWithCudaSupport,
-
     #[error(
         "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
     )]
@@ -71,6 +72,7 @@ pub enum Error {
         nth_shape: Shape,
     },
 
+    // === Device Errors ===
     #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
     DeviceMismatchBinaryOp {
         lhs: DeviceLocation,
@@ -78,27 +80,56 @@ pub enum Error {
         op: &'static str,
     },
 
-    #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
-    DTypeMismatchBinaryOp {
-        lhs: DType,
-        rhs: DType,
-        op: &'static str,
-    },
-
-    #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
-    UnexpectedNumberOfDims {
-        expected: usize,
-        got: usize,
+    // === Op Specific Errors ===
+    #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
+    NarrowInvalidArgs {
         shape: Shape,
+        dim: usize,
+        start: usize,
+        len: usize,
+        msg: &'static str,
     },
 
-    // TODO this is temporary when we support arbitrary matmul
-    #[error("temporary error where matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")]
-    UnexpectedStriding {
-        lhs_stride: Vec<usize>,
-        rhs_stride: Vec<usize>,
+    #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")]
+    Conv1dInvalidArgs {
+        inp_shape: Shape,
+        k_shape: Shape,
+        padding: usize,
+        stride: usize,
+        msg: &'static str,
     },
 
+    #[error("{op} invalid index {index} with vocab {vocab_size}")]
+    InvalidIndex {
+        op: &'static str,
+        index: usize,
+        vocab_size: usize,
+    },
+
+    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
+    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
+
+    // Box indirection to avoid large variant.
+    #[error("{0:?}")]
+    MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),
+
+    #[error("{op} only supports contiguous tensors")]
+    RequiresContiguous { op: &'static str },
+
+    #[error("{op} expects at least one tensor")]
+    OpRequiresAtLeastOneTensor { op: &'static str },
+
+    #[error("backward is not supported for {op}")]
+    BackwardNotSupported { op: &'static str },
+
+    // === Other Errors ===
+    #[error("the candle crate has not been built with cuda support")]
+    NotCompiledWithCudaSupport,
+
+    #[error("cannot find tensor {path}")]
+    CannotFindTensor { path: String },
+
+    // === Wrapped Errors ===
     #[error(transparent)]
     Cuda(Box<dyn std::error::Error + Send + Sync>),
 
@@ -126,22 +157,6 @@ pub enum Error {
 
     #[error("unsupported safetensor dtype {0:?}")]
     UnsupportedSafeTensorDtype(safetensors::Dtype),
-
-    #[error("unsupported dtype {0:?} for op {1}")]
-    UnsupportedDTypeForOp(DType, &'static str),
-
-    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
-    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
-
-    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
-    MatMulNonContiguous {
-        lhs_stride: Vec<usize>,
-        rhs_stride: Vec<usize>,
-        mnk: (usize, usize, usize),
-    },
-
-    #[error("cannot find tensor {path}")]
-    CannotFindTensor { path: String },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
@@ -60,20 +60,26 @@ impl Layout {
         self.shape.is_fortran_contiguous(&self.stride)
     }
 
-    pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
+    pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
         let dims = self.shape().dims();
         if dim >= dims.len() {
-            Err(Error::UnexpectedNumberOfDims {
-                expected: dim + 1,
-                got: dims.len(),
+            Err(Error::DimOutOfRange {
                 shape: self.shape().clone(),
+                dim: dim as i32,
+                op: "narrow",
             })?
         }
-        if start + length > dims[dim] {
-            todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
+        if start + len > dims[dim] {
+            Err(Error::NarrowInvalidArgs {
+                shape: self.shape.clone(),
+                dim,
+                start,
+                len,
+                msg: "start + len > dim_len",
+            })?
         }
         let mut dims = dims.to_vec();
-        dims[dim] = length;
+        dims[dim] = len;
         Ok(Self {
             shape: Shape::from(dims),
             stride: self.stride.clone(),
@@ -194,7 +194,7 @@ impl Dim for usize {
         if dim >= shape.dims().len() {
             Err(Error::DimOutOfRange {
                 shape: shape.clone(),
-                dim,
+                dim: dim as i32,
                 op,
             })?
         } else {
@@ -207,7 +207,7 @@ impl Dim for usize {
         if dim > shape.dims().len() {
             Err(Error::DimOutOfRange {
                 shape: shape.clone(),
-                dim,
+                dim: dim as i32,
                 op,
             })?
         } else {
@@ -221,30 +221,36 @@ pub enum D {
     Minus2,
 }
 
+impl D {
+    fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
+        let dim = match self {
+            Self::Minus1 => -1,
+            Self::Minus2 => -2,
+        };
+        Error::DimOutOfRange {
+            shape: shape.clone(),
+            dim,
+            op,
+        }
+    }
+}
+
 impl Dim for D {
     fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
         let rank = shape.rank();
         match self {
             Self::Minus1 if rank >= 1 => Ok(rank - 1),
             Self::Minus2 if rank >= 2 => Ok(rank - 2),
-            _ => Err(Error::DimOutOfRange {
-                shape: shape.clone(),
-                dim: 42, // TODO: Have an adequate error
-                op,
-            }),
+            _ => Err(self.out_of_range(shape, op)),
         }
     }
 
     fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
         let rank = shape.rank();
         match self {
-            Self::Minus1 if rank >= 1 => Ok(rank),
-            Self::Minus2 if rank >= 2 => Ok(rank - 1),
-            _ => Err(Error::DimOutOfRange {
-                shape: shape.clone(),
-                dim: 42, // TODO: Have an adequate error
-                op,
-            }),
+            Self::Minus1 => Ok(rank),
+            Self::Minus2 if rank >= 1 => Ok(rank - 1),
+            _ => Err(self.out_of_range(shape, op)),
         }
     }
 }
@@ -490,7 +490,7 @@ impl Tensor {
         if dim >= self.dims().len() {
             Err(Error::DimOutOfRange {
                 shape: self.shape().clone(),
-                dim,
+                dim: dim as i32,
                 op,
             })?
         } else {
@@ -509,6 +509,7 @@ impl Tensor {
                 dim,
                 start,
                 len,
+                msg: "start + len > dim_len",
             })?
         }
         if start == 0 && dims[dim] == len {
@@ -576,10 +577,22 @@ impl Tensor {
         let (b_size, c_in, l_in) = match *self.dims() {
             [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
             [c_in, l_in] => (None, c_in, l_in),
-            _ => todo!("proper error message"),
+            _ => Err(Error::Conv1dInvalidArgs {
+                inp_shape: self.shape().clone(),
+                k_shape: kernel.shape().clone(),
+                padding,
+                stride,
+                msg: "input rank is not 2 or 3",
+            })?,
         };
         if c_in != c_in_k {
-            todo!("proper error message")
+            Err(Error::Conv1dInvalidArgs {
+                inp_shape: self.shape().clone(),
+                k_shape: kernel.shape().clone(),
+                padding,
+                stride,
+                msg: "the number of in-channels on the input doesn't match the kernel size",
+            })?
         }
         let params = crate::conv::ParamsConv1D {
             b_size,
@@ -157,8 +157,9 @@ impl<'a> VarBuilder<'a> {
                 routing,
                 safetensors,
             } => {
-                // Unwrap or 0 just to let the proper error flow.
-                let index = routing.get(&path).unwrap_or(&0);
+                let index = routing.get(&path).ok_or_else(|| Error::CannotFindTensor {
+                    path: path.to_string(),
+                })?;
                 safetensors[*index]
                     .tensor(&path, &data.device)?
                     .to_dtype(data.dtype)?