From 1f4c54493e7db0f1d39a1ca5363d1fca61fa91fe Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sun, 21 Jan 2024 18:12:49 +0100 Subject: [PATCH] Improve arg reduce and add contiguous impl --- candle-core/benches/bench_main.rs | 2 +- candle-core/benches/benchmarks/reduce.rs | 42 +- candle-core/src/metal_backend.rs | 71 ++-- candle-metal-kernels/src/reduce.metal | 477 ++++++++++++++--------- 4 files changed, 358 insertions(+), 234 deletions(-) diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 07668c81..88b5dfd7 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,4 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::reduce::benches); \ No newline at end of file +criterion_main!(benchmarks::reduce::benches); diff --git a/candle-core/benches/benchmarks/reduce.rs b/candle-core/benches/benchmarks/reduce.rs index f44a1730..8216d9d0 100644 --- a/candle-core/benches/benchmarks/reduce.rs +++ b/candle-core/benches/benchmarks/reduce.rs @@ -1,19 +1,25 @@ -use candle_core::{DType, Tensor}; +use crate::benchmarks::{bench_name, device, BenchDevice}; +use candle_core::{DType, Device, Tensor}; use criterion::{black_box, criterion_group, Criterion, Throughput}; use std::time::Instant; -use crate::benchmarks::{bench_name, device, BenchDevice}; -fn run(a: &Tensor) { +fn run_sum(a: &Tensor) { a.sum(2).unwrap(); } +fn run_arg_min(a: &Tensor) { + a.argmin(2).unwrap(); +} fn criterion_benchmark(c: &mut Criterion) { + let device = device().unwrap(); + run_reduce(c, &device); + run_arg_reduce(c, &device); +} +fn run_reduce(c: &mut Criterion, device: &Device) { let b = 1; let m = 2048; let k = 2048; - let device = device().unwrap(); - let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap(); let flops = b * m * k * DType::F32.size_in_bytes(); @@ -24,7 +30,31 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter_custom(|iters| { let start = Instant::now(); for _i in 0..iters { - run(black_box(&a)); + run_sum(black_box(&a)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn run_arg_reduce(c: &mut Criterion, device: &Device) { + let b = 1; + let m = 2048; + let k = 2048; + + let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap(); + + let flops = b * m * k * DType::F32.size_in_bytes(); + + let mut group = c.benchmark_group(bench_name("arg_reduce")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run_arg_min(black_box(&a)); } device.sync().unwrap(); start.elapsed() diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6d5232ce..1eeb53c0 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -511,59 +511,56 @@ impl BackendStorage for MetalStorage { (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false), (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false), (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false), - //(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true), - //(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true), (ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false), (ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false), (ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false), - //(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true), - //(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true), + (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true), + (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true), (ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false), (ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false), (ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false), - //(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true), - //(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true), + (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true), + (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true), (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false), (ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false), (ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false), - //(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true), - //(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true), + (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true), + (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true), (ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false), (ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false), (ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false), - //(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true), - //(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true), + (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true), + (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true), (ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false), (ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false), (ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false), - //(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true), - //(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true), - //(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"), - _ => ("fall back to strided impl", false, false) - }; - - if name != "fall back to strided impl" { - if check_empty && layout.shape().elem_count() == 0 { - Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? + (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true), + (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true), + (k, dtype) => { + crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented") } - - - let buffer = device.new_buffer(1, self.dtype, "reduce")?; - let command_buffer = self.device.command_buffer()?; - candle_metal_kernels::call_reduce_contiguous( - &device.device, - &command_buffer, - &device.kernels, - name, - layout.shape().elem_count(), - dst_el, - &self.buffer, - layout.start_offset() * self.dtype.size_in_bytes(), - &buffer, - ) - .map_err(MetalError::from)?; - return Ok(Self::new(buffer, device, self.dtype)); + }; + if check_empty && layout.shape().elem_count() == 0 { + Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? } + + let buffer = device.new_buffer(1, self.dtype, "reduce")?; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_reduce_contiguous( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.shape().elem_count(), + dst_el, + &self.buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + return Ok(Self::new(buffer, device, self.dtype)); } for &dim_idx in sum_dims.iter() { @@ -602,7 +599,7 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false), (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true), (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true), - (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"), + (k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"), }; if check_empty && layout.shape().elem_count() == 0 { Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index d297a527..5c8963cd 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -19,24 +19,24 @@ METAL_FUNC uint get_strided_index( } #define impl_reduction_op(name, op, init_val) \ -template \ +template \ struct name { \ \ static constexpr constant T init = init_val; \ \ - METAL_FUNC T operator()(thread const T &a, thread const T &b) const { \ + METAL_FUNC R operator()(thread const T &a, thread const T &b) const { \ return op; \ } \ \ - METAL_FUNC T operator()(threadgroup const T &a, threadgroup const T &b) const { \ + METAL_FUNC R operator()(threadgroup const T &a, threadgroup const T &b) const { \ return op; \ } \ \ - METAL_FUNC T operator()(device const T &a, device const T &b) const { \ + METAL_FUNC R operator()(device const T &a, device const T &b) const { \ return op; \ } \ \ - METAL_FUNC T operator()(T a, T b) { \ + METAL_FUNC R operator()(T a, T b) { \ return op; \ } \ } \ @@ -45,10 +45,13 @@ impl_reduction_op(Sum, a + b, 0); impl_reduction_op(Mul, a * b, 1); impl_reduction_op(Min, a < b ? a : b, numeric_limits::max()); impl_reduction_op(Max, a > b ? a : b, numeric_limits::min()); +impl_reduction_op(ArgMin, a < b, numeric_limits::max()); +impl_reduction_op(ArgMax, a > b, numeric_limits::min()); #undef impl_reduction_op static constant constexpr int THREADGROUP_SIZE = 2048; +// Load strided elements from global memory into shared memory. template METAL_FUNC void load_from_global( constant size_t &num_dims, @@ -74,6 +77,40 @@ METAL_FUNC void load_from_global( threadgroup_barrier(mem_flags::mem_none); } +// Load strided elements from global memory into shared memory with indices. +template +METAL_FUNC void load_from_global( + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant size_t &el_to_sum_per_block, + device const T *src, + threadgroup T shared[BLOCKSIZE], + threadgroup uint shared_indices[BLOCKSIZE], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]], + uint block_dim [[ threads_per_threadgroup ]] +) { + ArgReductionOp op; + bool notset = true; + + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = start_idx + el_to_sum_per_block; + size_t idx = start_idx + tid; + while (idx < stop_idx) { + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + if (notset || op(src[strided_i], shared[tid])) { + shared[tid] = src[strided_i]; + // Assume that the reduction takes place over the last dimension which is contiguous. + shared_indices[tid] = idx % dims[num_dims - 1]; + notset = false; + } + idx += block_dim; + } + threadgroup_barrier(mem_flags::mem_none); +} + +// Load contiguous elements from global memory into shared memory. template METAL_FUNC void load_from_global( constant size_t &num_dims, @@ -97,6 +134,45 @@ METAL_FUNC void load_from_global( threadgroup_barrier(mem_flags::mem_none); } +// Load contiguous elements from global memory into shared memory with indices. +template +METAL_FUNC void load_from_global( + constant size_t &num_dims, + constant size_t *dims, + constant size_t &el_to_sum_per_block, + device const T *src, + threadgroup T shared[BLOCKSIZE], + threadgroup uint shared_indices[BLOCKSIZE], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]], + uint block_dim [[ threads_per_threadgroup ]] +) { + ArgReductionOp op; + bool notset = true; + + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = start_idx + el_to_sum_per_block; + size_t idx = start_idx + tid; + while (idx < stop_idx) { + if (notset || op(src[idx], shared[tid])) { + shared[tid] = src[idx]; + // Assume that the reduction takes place over the last dimension which is contiguous. + shared_indices[tid] = idx % dims[num_dims - 1]; + notset = false; + } + idx += block_dim; + } + threadgroup_barrier(mem_flags::mem_none); +} + +#define reduce_threadgroup(SIZE) \ +if (BLOCKSIZE >= SIZE) { \ + if (block_dim >= SIZE) { \ + shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ +} + template METAL_FUNC void threadgroup_reduce( threadgroup T shared[BLOCKSIZE], @@ -104,37 +180,50 @@ METAL_FUNC void threadgroup_reduce( uint block_dim [[ threads_per_threadgroup ]] ) { ReductionOp op; - if (BLOCKSIZE >= 64) { - if (block_dim >= 64) { - shared[tid] = op(shared[tid], shared[tid + 32]); - } - } - if (BLOCKSIZE >= 32) { - if (block_dim >= 32) { - shared[tid] = op(shared[tid], shared[tid + 16]); - } - } - if (BLOCKSIZE >= 16) { - if (block_dim >= 16) { - shared[tid] = op(shared[tid], shared[tid + 8]); - } - } - if (BLOCKSIZE >= 8) { - if (block_dim >= 8) { - shared[tid] = op(shared[tid], shared[tid + 4]); - } - } - if (BLOCKSIZE >= 4) { - if (block_dim >= 4) { - shared[tid] = op(shared[tid], shared[tid + 2]); - } - } - if (BLOCKSIZE >= 2) { - if (block_dim >= 2) { - shared[tid] = op(shared[tid], shared[tid + 1]); - } - } + reduce_threadgroup(64); + reduce_threadgroup(32); + reduce_threadgroup(16); + reduce_threadgroup(8); + reduce_threadgroup(4); + reduce_threadgroup(2); } +#undef reduce_threadgroup + +#define arg_reduce_threadgroup(SIZE) \ +if (BLOCKSIZE >= SIZE) { \ + if (block_dim >= SIZE && \ + op(shared[tid], shared[tid + SIZE / 2]) \ + ) { \ + shared_indices[tid] = shared_indices[tid + SIZE / 2]; \ + shared[tid] = shared[tid + SIZE / 2]; \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ +} + +template +METAL_FUNC void threadgroup_reduce( + threadgroup T shared[BLOCKSIZE], + threadgroup uint shared_indices[BLOCKSIZE], + uint tid [[thread_index_in_threadgroup]], + uint block_dim [[ threads_per_threadgroup ]] +) { + ArgReductionOp op; + arg_reduce_threadgroup(64); + arg_reduce_threadgroup(32); + arg_reduce_threadgroup(16); + arg_reduce_threadgroup(8); + arg_reduce_threadgroup(4); + arg_reduce_threadgroup(2); +} +#undef arg_reduce_threadgroup + +#define reduce_block(SIZE) \ +if (BLOCKSIZE >= SIZE) { \ + if (tid < SIZE / 2 && block_dim >= SIZE) { \ + shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ +} \ // Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris template< @@ -186,42 +275,20 @@ METAL_FUNC void block_reduce( ); } - if (BLOCKSIZE >= 1024) { - if (tid < 512 && block_dim >= 1024) { - shared[tid] = op(shared[tid], shared[tid + 512]); - threadgroup_barrier(mem_flags::mem_none); - } - } - if (BLOCKSIZE >= 512) { - if (tid < 256 && block_dim >= 512) { - shared[tid] = op(shared[tid], shared[tid + 256]); - threadgroup_barrier(mem_flags::mem_none); - } - } - if (BLOCKSIZE >= 256) { - if (tid < 128 && block_dim >= 256) { - shared[tid] = op(shared[tid], shared[tid + 128]); - threadgroup_barrier(mem_flags::mem_none); - } - } - if (BLOCKSIZE >= 128) { - if (tid < 64 && block_dim >= 128) { - shared[tid] = op(shared[tid], shared[tid + 64]); - threadgroup_barrier(mem_flags::mem_none); - } - } + reduce_block(1024); + reduce_block(512); + reduce_block(256); + reduce_block(128); + if (tid < 32) { threadgroup_reduce(shared, tid, block_dim); threadgroup_barrier(mem_flags::mem_none); } - if (tid == 0) { dst[dst_id] = shared[tid]; } } - -#define MAX(x, y) ((x) > (y) ? (x) : (y)) -#define MIN(x, y) ((x) < (y) ? (x) : (y)) +#undef reduce_block static constant constexpr int BLOCKSIZE = 2048; @@ -283,123 +350,146 @@ kernel void NAME##_strided( \ block_dim); \ } \ -#define ARGMIN(NAME, T, MAXVALUE) \ -kernel void NAME( \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device uint *dst, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - \ - threadgroup T shared_memory[THREADGROUP_SIZE]; \ - threadgroup uint shared_indices[THREADGROUP_SIZE]; \ - \ - shared_memory[tid] = MAXVALUE; \ - shared_indices[tid] = 0xFFFFFFFF; \ - bool notset = true; \ - /* \ - // Elements summed in this block range from dst_id * el_to_sum_per_block \ - // to (dst_id + 1) * el_to_sum_per_block. \ - */ \ - size_t start_idx = dst_id * el_to_sum_per_block; \ - size_t stop_idx = start_idx + el_to_sum_per_block; \ - size_t idx = start_idx + tid; \ - while (idx < stop_idx) { \ - /* \ - // TODO: Fast version for the contiguous case. \ - */ \ - size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ - if (notset || src[strided_i] < shared_memory[tid]) { \ - shared_memory[tid] = src[strided_i]; \ - /* Assume that the reduction takes place over the last dimension which is contiguous. */ \ - shared_indices[tid] = idx % dims[num_dims - 1]; \ - notset = false; \ - } \ - idx += block_dim; \ - } \ - \ - threadgroup_barrier(mem_flags::mem_none); \ - \ - /* \ - // reduction in shared memory \ - */ \ - for (uint s = block_dim / 2; s > 0; s >>= 1) { \ - if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \ - shared_indices[tid] = shared_indices[tid + s]; \ - shared_memory[tid] = shared_memory[tid + s]; \ - } \ - threadgroup_barrier(mem_flags::mem_none); \ - } \ - \ - if (tid == 0){ \ - dst[dst_id] = shared_indices[0]; \ - } \ -} \ +#define arg_reduce_block(SIZE) \ +if (BLOCKSIZE >= SIZE) { \ + if (tid < SIZE / 2 \ + && block_dim >= SIZE \ + && arg_op(shared[tid], shared[tid + SIZE / 2]) \ + ) { \ + shared_indices[tid] = shared_indices[tid + SIZE / 2]; \ + shared[tid] = shared[tid + SIZE / 2]; \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ +} \ + +template< + typename T, + typename ArgReductionOp, + uint BLOCKSIZE, + bool STRIDED +> +METAL_FUNC void arg_block_reduce( + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant size_t &el_to_sum_per_block, + device const T *src, + device uint *dst, + threadgroup T shared[BLOCKSIZE], + threadgroup uint shared_indices[BLOCKSIZE], + uint id [[ thread_position_in_grid ]], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]], + uint block_dim [[ threads_per_threadgroup ]] +) { + ArgReductionOp arg_op; + + shared[tid] = ArgReductionOp::init; + shared_indices[tid] = numeric_limits::max(); + + if (STRIDED) { + load_from_global( + num_dims, + dims, + strides, + el_to_sum_per_block, + src, + shared, + shared_indices, + tid, + dst_id, + block_dim + ); + } else { + load_from_global( + num_dims, + dims, + el_to_sum_per_block, + src, + shared, + shared_indices, + tid, + dst_id, + block_dim + ); + } + arg_reduce_block(1024); + arg_reduce_block(512); + arg_reduce_block(256); + arg_reduce_block(128); + + if (tid < 32) { + threadgroup_reduce(shared, shared_indices, tid, block_dim); + threadgroup_barrier(mem_flags::mem_none); + } + + if (tid == 0) { + dst[dst_id] = shared_indices[0]; + } +} +#undef arg_reduce_block + +#define ARG_REDUCE(OP, NAME, T) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup T shared[BLOCKSIZE]; \ + threadgroup uint shared_indices[BLOCKSIZE]; \ + arg_block_reduce, BLOCKSIZE, false>( \ + num_dims, \ + dims, \ + strides, \ + el_to_sum_per_block, \ + src, \ + dst, \ + shared, \ + shared_indices, \ + id, \ + tid, \ + dst_id, \ + block_dim); \ +} \ +kernel void NAME##_strided( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup T shared[BLOCKSIZE]; \ + threadgroup uint shared_indices[BLOCKSIZE]; \ + arg_block_reduce, BLOCKSIZE, true>( \ + num_dims, \ + dims, \ + strides, \ + el_to_sum_per_block, \ + src, \ + dst, \ + shared, \ + shared_indices, \ + id, \ + tid, \ + dst_id, \ + block_dim); \ +} -#define ARGMAX(NAME, T, MINVALUE) \ -kernel void NAME( \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device uint *dst, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - \ - threadgroup T shared_memory[THREADGROUP_SIZE]; \ - threadgroup uint shared_indices[THREADGROUP_SIZE]; \ - \ - shared_memory[tid] = MINVALUE; \ - shared_indices[tid] = 0xFFFFFFFF; \ - /* \ - // Elements summed in this block range from dst_id * el_to_sum_per_block \ - // to (dst_id + 1) * el_to_sum_per_block. \ - */ \ - size_t start_idx = dst_id * el_to_sum_per_block; \ - size_t stop_idx = start_idx + el_to_sum_per_block; \ - size_t idx = start_idx + tid; \ - bool notset = true; \ - while (idx < stop_idx) { \ - /* \ - // TODO: Fast version for the contiguous case. \ - */ \ - size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ - if (notset || shared_memory[tid] < src[strided_i]) { \ - shared_memory[tid] = src[strided_i]; \ - shared_indices[tid] = idx % dims[num_dims - 1]; \ - notset = false; \ - } \ - idx += block_dim; \ - } \ - \ - threadgroup_barrier(mem_flags::mem_none); \ - \ - /* \ - // reduction in shared memory \ - */ \ - for (uint s = block_dim / 2; s > 0; s >>= 1) { \ - if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \ - shared_indices[tid] = shared_indices[tid + s]; \ - shared_memory[tid] = shared_memory[tid + s]; \ - } \ - threadgroup_barrier(mem_flags::mem_none); \ - } \ - \ - if (tid == 0){ \ - dst[dst_id] = shared_indices[0]; \ - } \ -} \ +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SOFTMAX(NAME, T) \ kernel void NAME( \ @@ -472,26 +562,31 @@ REDUCE(Sum, fast_sum_f32, float) REDUCE(Sum, fast_sum_u32, uint) REDUCE(Sum, fast_sum_f16, half) REDUCE(Sum, fast_sum_u8, uint8_t) + REDUCE(Mul, fast_mul_f32, float) REDUCE(Mul, fast_mul_u32, uint) REDUCE(Mul, fast_mul_f16, half) +REDUCE(Mul, fast_mul_u8, uint8_t) + REDUCE(Max, fast_max_f32, float) REDUCE(Max, fast_max_u32, uint) REDUCE(Max, fast_max_f16, half) REDUCE(Max, fast_max_u8, uint8_t) + REDUCE(Min, fast_min_f32, float) REDUCE(Min, fast_min_u32, uint) REDUCE(Min, fast_min_f16, half) REDUCE(Min, fast_min_u8, uint8_t) -ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF) -ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH) -ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF) -ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF) -ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF) -ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH) -ARGMAX(fast_argmax_u32_strided, uint, 0) -ARGMAX(fast_argmax_u8_strided, uint8_t, 0) +ARG_REDUCE(ArgMin, fast_argmin_f32, float) +ARG_REDUCE(ArgMin, fast_argmin_f16, half) +ARG_REDUCE(ArgMin, fast_argmin_u32, uint) +ARG_REDUCE(ArgMin, fast_argmin_u8, uint8_t) + +ARG_REDUCE(ArgMax, fast_argmax_f32, float) +ARG_REDUCE(ArgMax, fast_argmax_f16, half) +ARG_REDUCE(ArgMax, fast_argmax_u32, uint) +ARG_REDUCE(ArgMax, fast_argmax_u8, uint8_t) SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) @@ -502,8 +597,9 @@ REDUCE(Mul, fast_mul_i64, int64_t) REDUCE(Min, fast_min_i64, int64_t) REDUCE(Max, fast_max_i64, int64_t) -ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) -ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) +ARG_REDUCE(ArgMin, fast_argmin_i64, int64_t) +ARG_REDUCE(ArgMax, fast_argmax_i64, int64_t) + #endif #if __METAL_VERSION__ >= 310 @@ -512,7 +608,8 @@ REDUCE(Mul, fast_mul_bf16, bfloat) REDUCE(Max, fast_max_bf16, bfloat) REDUCE(Min, fast_min_bf16, bfloat) -ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) -ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) +ARG_REDUCE(ArgMin, fast_argmin_bf16, bfloat) +ARG_REDUCE(ArgMax, fast_argmax_bf16, bfloat) + SOFTMAX(softmax_bf16, bfloat) #endif