From 662c186fd509af12ef69ccc660607618d8afd297 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 18 Oct 2023 08:40:14 +0100 Subject: [PATCH] Better error message when overflowing in narrow. (#1119) --- candle-core/src/tensor.rs | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e2c97af2..0f028dd8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -615,15 +615,23 @@ impl Tensor { pub fn narrow(&self, dim: D, start: usize, len: usize) -> Result { let dims = self.dims(); let dim = dim.to_index(self.shape(), "narrow")?; - if start + len > dims[dim] { - Err(Error::NarrowInvalidArgs { - shape: self.shape().clone(), - dim, - start, - len, - msg: "start + len > dim_len", - } - .bt())? + let err = |msg| { + Err::<(), _>( + Error::NarrowInvalidArgs { + shape: self.shape().clone(), + dim, + start, + len, + msg, + } + .bt(), + ) + }; + if start > dims[dim] { + err("start > dim_len")? + } + if start.saturating_add(len) > dims[dim] { + err("start + len > dim_len")? } if start == 0 && dims[dim] == len { Ok(self.clone())