Do not panic on empty ranges. (#257)

This commit is contained in:
Laurent Mazare
2023-07-27 09:28:47 +01:00
committed by GitHub
parent 209f06d7c3
commit f291065f6c
2 changed files with 14 additions and 1 deletions

View File

@ -42,7 +42,7 @@ impl Tensor {
Bound::Excluded(n) => *n, Bound::Excluded(n) => *n,
Bound::Unbounded => dims[i], Bound::Unbounded => dims[i],
}; };
let out = x.narrow(current_dim, start, stop - start)?; let out = x.narrow(current_dim, start, stop.saturating_sub(start))?;
current_dim += 1; current_dim += 1;
out out
} }

View File

@ -58,6 +58,19 @@ fn range_index() -> Result<()> {
let result = tensor.i(..=1)?; let result = tensor.i(..=1)?;
assert_eq!(result.dims(), &[2, 3]); assert_eq!(result.dims(), &[2, 3]);
assert_eq!(result.to_vec2::<u32>()?, &[[0, 1, 2], [3, 4, 5]]); assert_eq!(result.to_vec2::<u32>()?, &[[0, 1, 2], [3, 4, 5]]);
// Empty range
let result = tensor.i(1..1)?;
assert_eq!(result.dims(), &[0, 3]);
let empty: [[u32; 3]; 0] = [];
assert_eq!(result.to_vec2::<u32>()?, &empty);
// Similar to PyTorch, allow empty ranges when the computed length is negative.
#[allow(clippy::reversed_empty_ranges)]
let result = tensor.i(1..0)?;
assert_eq!(result.dims(), &[0, 3]);
let empty: [[u32; 3]; 0] = [];
assert_eq!(result.to_vec2::<u32>()?, &empty);
Ok(()) Ok(())
} }