Improve how we check that the dims are in bounds.

commit 19cbbc5212 (parent 00476d37f8)
Author: laurent
Date:   2023-06-30 09:11:00 +01:00

3 changed files with 35 additions and 8 deletions
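
In short, the change centralizes the bounds check in a private `Tensor::check_dim(dim, op)` helper that returns `Error::DimOutOfRange` whenever `dim >= self.dims().len()`, and the public ops (`narrow`, `softmax`, `sum`, `flatten`, `transpose`, `squeeze`, `cat`) call it before indexing into the shape. Below is a minimal standalone sketch of that pattern; `MiniTensor` and the plain `Vec<usize>` shape are stand-ins for illustration only, not the crate's real types:

// Simplified stand-in types; the real crate uses `Shape`, `thiserror`, etc.
#[derive(Debug)]
enum Error {
    DimOutOfRange { shape: Vec<usize>, dim: usize, op: &'static str },
    NarrowInvalidArgs { shape: Vec<usize>, dim: usize, start: usize, len: usize },
}

type Result<T> = std::result::Result<T, Error>;

struct MiniTensor {
    dims: Vec<usize>,
}

impl MiniTensor {
    // Centralized bounds check: every op validates `dim` through this helper
    // so the error consistently names the offending op and the shape.
    fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
        if dim >= self.dims.len() {
            Err(Error::DimOutOfRange { shape: self.dims.clone(), dim, op })
        } else {
            Ok(())
        }
    }

    fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<()> {
        // Dimension bound first, then the op-specific range check.
        self.check_dim(dim, "narrow")?;
        if start + len > self.dims[dim] {
            return Err(Error::NarrowInvalidArgs { shape: self.dims.clone(), dim, start, len });
        }
        Ok(())
    }
}

fn main() {
    let t = MiniTensor { dims: vec![2, 3] };
    // dim 4 does not exist on a rank-2 tensor -> DimOutOfRange.
    println!("{:?}", t.narrow(4, 0, 1));
    // dim 1 exists but start + len exceeds its size -> NarrowInvalidArgs.
    println!("{:?}", t.narrow(1, 2, 5));
}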

.gitignore vendored

@@ -15,3 +15,6 @@ Cargo.lock
*tokenizer.json
*.npz
perf.data
flamegraph.svg


@@ -10,6 +10,13 @@ pub enum Error {
        got: DType,
    },
    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
    DimOutOfRange {
        shape: Shape,
        dim: usize,
        op: &'static str,
    },
    #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
    NarrowInvalidArgs {
        shape: Shape,
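
The `#[error(...)]` attribute above is the `thiserror`-style `Display` derive: the braces interpolate the variant's fields and `{shape:?}` uses the `Debug` representation. A small self-contained sketch of how the new variant renders (this assumes the `thiserror` crate, and `Shape` is stubbed here as a newtype over `Vec<usize>` rather than the crate's real definition):

use thiserror::Error as ThisError;

// Stand-in for the crate's `Shape` type, only for this illustration.
#[derive(Debug, Clone)]
struct Shape(Vec<usize>);

#[derive(ThisError, Debug)]
enum Error {
    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
    DimOutOfRange {
        shape: Shape,
        dim: usize,
        op: &'static str,
    },
}

fn main() {
    let err = Error::DimOutOfRange {
        shape: Shape(vec![2, 3]),
        dim: 4,
        op: "narrow",
    };
    // Prints: narrow: dimension index 4 out of range for Shape([2, 3])
    println!("{err}");
}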


@@ -348,11 +348,24 @@ impl Tensor {
        Ok(from_storage(storage, self.shape(), op, false))
    }

    fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
        if dim >= self.dims().len() {
            Err(Error::DimOutOfRange {
                shape: self.shape().clone(),
                dim,
                op,
            })?
        } else {
            Ok(())
        }
    }

    /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
    /// ranges from `start` to `start + len`.
    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
        let dims = self.dims();
-        if dim >= dims.len() || start + len > dims[dim] {
+        self.check_dim(dim, "narrow")?;
+        if start + len > dims[dim] {
            Err(Error::NarrowInvalidArgs {
                shape: self.shape().clone(),
                dim,
@@ -380,6 +393,7 @@ impl Tensor {
    }

    pub fn softmax(&self, dim: usize) -> Result<Self> {
        self.check_dim(dim, "softmax")?;
        // TODO: unify the two branches.
        if self.device().is_cuda() {
            // We do not have a cuda kernel for divide_by_sum_over_dim so split
@@ -402,6 +416,9 @@ impl Tensor {
    }

    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
        for &dim in sum_dims {
            self.check_dim(dim, "sum")?;
        }
        let storage = self.storage.sum(self.layout(), sum_dims)?;
        let op = if self.track_op() {
            Some(Op::Sum(self.clone(), sum_dims.to_vec()))
@@ -645,6 +662,8 @@ impl Tensor {
        } else {
            let start_dim = start_dim.unwrap_or(0);
            let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
            self.check_dim(start_dim, "flatten")?;
            self.check_dim(end_dim, "flatten")?;
            if start_dim < end_dim {
                let dims = self.dims();
                let mut dst_dims = dims[..start_dim].to_vec();
@@ -689,6 +708,8 @@ impl Tensor {
    /// Returns a tensor that is a transposed version of the input, the given dimensions are
    /// swapped.
    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
        self.check_dim(dim1, "transpose")?;
        self.check_dim(dim2, "transpose")?;
        let op = if self.track_op() {
            Some(Op::Transpose(self.clone(), dim1, dim2))
        } else {
@@ -876,6 +897,7 @@ impl Tensor {
        // The PyTorch semantics are to return the same tensor if the target dimension
        // does not have a size of 1.
        let dims = self.dims();
        self.check_dim(index, "squeeze")?;
        if dims[index] == 1 {
            let mut dims = dims.to_vec();
            dims.remove(index);
@@ -910,13 +932,8 @@ impl Tensor {
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
-        let rank = arg0.rank();
-        if dim >= rank {
-            return Err(Error::UnexpectedNumberOfDims {
-                expected: (dim + 1),
-                got: rank,
-                shape: arg0.shape().clone(),
-            });
+        for arg in args {
+            arg.as_ref().check_dim(dim, "cat")?;
        }
        if dim == 0 {
            Self::cat0(args)
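
A closing note on the `cat` hunk: the old code validated `dim` only against `arg0` and surfaced the failure as `UnexpectedNumberOfDims`, while the new loop runs `check_dim` on every argument and reports `DimOutOfRange` tagged with the `cat` op. A rough caller-side sketch of what that could look like; the `candle` crate name and the constructors (`Tensor::new`, `Device::Cpu`) are assumptions about the surrounding API, not something shown in this diff:

use candle::{Device, Error, Tensor};

fn main() -> Result<(), Error> {
    let a = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let b = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
    // Both inputs are rank 1, so concatenating along dim 1 is out of bounds.
    match Tensor::cat(&[&a, &b], 1) {
        Err(Error::DimOutOfRange { shape, dim, op }) => {
            eprintln!("{op}: dimension {dim} is out of range for {shape:?}")
        }
        Err(e) => eprintln!("unexpected error: {e}"),
        Ok(t) => println!("cat result has dims {:?}", t.dims()),
    }
    Ok(())
}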