Refactor the reduce ops in order to introduce argmin/argmax. (#212)

* Refactor the reduce ops in order to introduce argmin/argmax.

* Clippy fixes.

* Use the newly introduced argmax.

* Fix the strided case.

* Handle the non-contiguous case.
This commit is contained in:
Laurent Mazare
2023-07-21 12:41:08 +02:00
committed by GitHub
parent c60831aad4
commit 410654525f
7 changed files with 241 additions and 110 deletions

View File

@@ -562,6 +562,8 @@ impl<'a> Map1 for FastReduce<'a> {
ReduceOp::Sum => "fast_sum",
ReduceOp::Min => "fast_min",
ReduceOp::Max => "fast_max",
ReduceOp::ArgMin => "fast_argmin",
ReduceOp::ArgMax => "fast_argmax",
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// SAFETY: filled in by the follow up kernel.