Improve the error message on shape mismatch for cat. (#897)

* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
This commit is contained in:
Laurent Mazare
2023-09-19 15:09:47 +01:00
committed by GitHub
parent 06e46d7c3b
commit 4f91c8e109
4 changed files with 33 additions and 7 deletions

View File

@ -1907,6 +1907,34 @@ impl Tensor {
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 {
Self::cat0(args)
} else {