Better error message when overflowing in narrow. (#1119)

This commit is contained in:
Laurent Mazare
2023-10-18 08:40:14 +01:00
committed by GitHub
parent 2cd745a97c
commit 662c186fd5

View File

@ -615,15 +615,23 @@ impl Tensor {
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
let dim = dim.to_index(self.shape(), "narrow")?;
if start + len > dims[dim] {
Err(Error::NarrowInvalidArgs {
shape: self.shape().clone(),
dim,
start,
len,
msg: "start + len > dim_len",
}
.bt())?
let err = |msg| {
Err::<(), _>(
Error::NarrowInvalidArgs {
shape: self.shape().clone(),
dim,
start,
len,
msg,
}
.bt(),
)
};
if start > dims[dim] {
err("start > dim_len")?
}
if start.saturating_add(len) > dims[dim] {
err("start + len > dim_len")?
}
if start == 0 && dims[dim] == len {
Ok(self.clone())