Refactor the reduce ops in order to introduce argmin/argmax. (#212)

* Refactor the reduce ops in order to introduce argmin/argmax.

* Clippy fixes.

* Use the newly introduced argmax.

* Fix the strided case.

* Handle the non-contiguous case.
This commit is contained in:
Laurent Mazare
2023-07-21 12:41:08 +02:00
committed by GitHub
parent c60831aad4
commit 410654525f
7 changed files with 241 additions and 110 deletions

View File

@ -17,6 +17,20 @@ pub enum ReduceOp {
Sum,
Min,
Max,
ArgMin,
ArgMax,
}
impl ReduceOp {
pub(crate) fn name(&self) -> &'static str {
match self {
Self::ArgMax => "argmax",
Self::ArgMin => "argmin",
Self::Min => "min",
Self::Max => "max",
Self::Sum => "sum",
}
}
}
// These ops return the same type as their input type.