mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax. * Clippy fixes. * Use the newly introduced argmax. * Fix the strided case. * Handle the non-contiguous case.
This commit is contained in:
@ -562,6 +562,8 @@ impl<'a> Map1 for FastReduce<'a> {
|
||||
ReduceOp::Sum => "fast_sum",
|
||||
ReduceOp::Min => "fast_min",
|
||||
ReduceOp::Max => "fast_max",
|
||||
ReduceOp::ArgMin => "fast_argmin",
|
||||
ReduceOp::ArgMax => "fast_argmax",
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
|
||||
// SAFETY: filled in by the follow up kernel.
|
||||
|
Reference in New Issue
Block a user