From 19cbbc5212ff31092dc46286c5977fddc2327674 Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 30 Jun 2023 09:11:00 +0100
Subject: [PATCH] Improve how we check that the dims are in bounds.

---
 .gitignore                |  3 +++
 candle-core/src/error.rs  |  7 +++++++
 candle-core/src/tensor.rs | 33 +++++++++++++++++++++++++--------
 3 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/.gitignore b/.gitignore
index 86bf7d2f..33593c9b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,6 @@
 Cargo.lock
 *tokenizer.json
 *.npz
+
+perf.data
+flamegraph.svg
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 637fd8b7..341fc151 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -10,6 +10,13 @@ pub enum Error {
         got: DType,
     },
 
+    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
+    DimOutOfRange {
+        shape: Shape,
+        dim: usize,
+        op: &'static str,
+    },
+
     #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
     NarrowInvalidArgs {
         shape: Shape,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 2f05094b..a468d879 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -348,11 +348,24 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
+        if dim >= self.dims().len() {
+            Err(Error::DimOutOfRange {
+                shape: self.shape().clone(),
+                dim,
+                op,
+            })?
+        } else {
+            Ok(())
+        }
+    }
+
     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
     /// ranges from `start` to `start + len`.
     pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
         let dims = self.dims();
-        if dim >= dims.len() || start + len > dims[dim] {
+        self.check_dim(dim, "narrow")?;
+        if start + len > dims[dim] {
             Err(Error::NarrowInvalidArgs {
                 shape: self.shape().clone(),
                 dim,
@@ -380,6 +393,7 @@
     }
 
     pub fn softmax(&self, dim: usize) -> Result<Self> {
+        self.check_dim(dim, "softmax")?;
         // TODO: unify the two branches.
         if self.device().is_cuda() {
             // We do not have a cuda kernel for divide_by_sum_over_dim so split
@@ -402,6 +416,9 @@
     }
     pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+        for &dim in sum_dims {
+            self.check_dim(dim, "sum")?;
+        }
         let storage = self.storage.sum(self.layout(), sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -645,6 +662,8 @@
         } else {
             let start_dim = start_dim.unwrap_or(0);
             let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
+            self.check_dim(start_dim, "flatten")?;
+            self.check_dim(end_dim, "flatten")?;
             if start_dim < end_dim {
                 let dims = self.dims();
                 let mut dst_dims = dims[..start_dim].to_vec();
@@ -689,6 +708,8 @@
     /// Returns a tensor that is a transposed version of the input, the given dimensions are
     /// swapped.
     pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
+        self.check_dim(dim1, "transpose")?;
+        self.check_dim(dim2, "transpose")?;
         let op = if self.track_op() {
             Some(Op::Transpose(self.clone(), dim1, dim2))
         } else {
@@ -876,6 +897,7 @@
         // The PyTorch semantics are to return the same tensor if the target dimension
         // does not have a size of 1.
         let dims = self.dims();
+        self.check_dim(index, "squeeze")?;
         if dims[index] == 1 {
             let mut dims = dims.to_vec();
             dims.remove(index);
@@ -910,13 +932,8 @@
         if args.len() == 1 {
            return Ok(arg0.clone());
        }
-        let rank = arg0.rank();
-        if dim >= rank {
-            return Err(Error::UnexpectedNumberOfDims {
-                expected: (dim + 1),
-                got: rank,
-                shape: arg0.shape().clone(),
-            });
+        for arg in args {
+            arg.as_ref().check_dim(dim, "cat")?;
         }
         if dim == 0 {
             Self::cat0(args)
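A minimal usage sketch of the behaviour this patch aims for (illustration only, not part of the patch): an out-of-range dimension index now surfaces as the dedicated `DimOutOfRange` error rather than an index panic or a less specific error. The sketch assumes the crate is consumed as `candle_core` and that `Tensor::zeros` is available with this signature at this revision; both are assumptions, not something the patch adds.

    use candle_core::{DType, Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Rank-2 tensor: the only valid dimension indices are 0 and 1.
        let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
        // Dimension 5 is out of bounds, so `narrow` should fail via `check_dim`,
        // e.g. "narrow: dimension index 5 out of range for [2, 3]".
        let err = t.narrow(5, 0, 1).unwrap_err();
        println!("{err}");
        Ok(())
    }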