Support Minus(u) for arbitrary values of u, e.g. Minus(3). (#2428)

* Support Minus(u) for arbitrary values of u, e.g. Minus(3).

* Forces u to be strictly positive.
This commit is contained in:
Laurent Mazare
2024-08-17 20:29:01 +01:00
committed by GitHub
parent b75ef051cf
commit 7cff5898ec

View File

@ -304,6 +304,7 @@ impl Dim for usize {
pub enum D {
Minus1,
Minus2,
Minus(usize),
}
impl D {
@ -311,6 +312,7 @@ impl D {
let dim = match self {
Self::Minus1 => -1,
Self::Minus2 => -2,
Self::Minus(u) => -(*u as i32),
};
Error::DimOutOfRange {
shape: shape.clone(),
@ -327,6 +329,7 @@ impl Dim for D {
match self {
Self::Minus1 if rank >= 1 => Ok(rank - 1),
Self::Minus2 if rank >= 2 => Ok(rank - 2),
Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
@ -336,6 +339,7 @@ impl Dim for D {
match self {
Self::Minus1 => Ok(rank),
Self::Minus2 if rank >= 1 => Ok(rank - 1),
Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
_ => Err(self.out_of_range(shape, op)),
}
}