mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Support Minus(u) for arbitrary values of u, e.g. Minus(3). (#2428)
* Support Minus(u) for arbitrary values of u, e.g. Minus(3). * Forces u to be strictly positive.
This commit is contained in:
@ -304,6 +304,7 @@ impl Dim for usize {
|
||||
pub enum D {
|
||||
Minus1,
|
||||
Minus2,
|
||||
Minus(usize),
|
||||
}
|
||||
|
||||
impl D {
|
||||
@ -311,6 +312,7 @@ impl D {
|
||||
let dim = match self {
|
||||
Self::Minus1 => -1,
|
||||
Self::Minus2 => -2,
|
||||
Self::Minus(u) => -(*u as i32),
|
||||
};
|
||||
Error::DimOutOfRange {
|
||||
shape: shape.clone(),
|
||||
@ -327,6 +329,7 @@ impl Dim for D {
|
||||
match self {
|
||||
Self::Minus1 if rank >= 1 => Ok(rank - 1),
|
||||
Self::Minus2 if rank >= 2 => Ok(rank - 2),
|
||||
Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
|
||||
_ => Err(self.out_of_range(shape, op)),
|
||||
}
|
||||
}
|
||||
@ -336,6 +339,7 @@ impl Dim for D {
|
||||
match self {
|
||||
Self::Minus1 => Ok(rank),
|
||||
Self::Minus2 if rank >= 1 => Ok(rank - 1),
|
||||
Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
|
||||
_ => Err(self.out_of_range(shape, op)),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user