From 8babfe0411acfbee867c6aeebe2077d5b91ef27a Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 30 Jan 2024 14:12:57 +0100 Subject: [PATCH] Fixed all bugs. Improved code quality. Added tests. --- candle-core/benches/bench_main.rs | 8 +- candle-core/benches/benchmarks/reduce.rs | 99 +- candle-metal-kernels/src/reduce.metal | 1041 ++++++++++----------- candle-metal-kernels/src/reduce_old.metal | 346 +++++++ candle-metal-kernels/src/tests.rs | 129 ++- 5 files changed, 1062 insertions(+), 561 deletions(-) create mode 100644 candle-metal-kernels/src/reduce_old.metal diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 7f4ed986..672ed0b2 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -3,9 +3,9 @@ mod benchmarks; use criterion::criterion_main; criterion_main!( - benchmarks::affine::benches, - benchmarks::matmul::benches, - benchmarks::random::benches, + //benchmarks::affine::benches, + //benchmarks::matmul::benches, + //benchmarks::random::benches, benchmarks::reduce::benches, - benchmarks::where_cond::benches + //benchmarks::where_cond::benches ); diff --git a/candle-core/benches/benchmarks/reduce.rs b/candle-core/benches/benchmarks/reduce.rs index 23ef4f1b..7654a02c 100644 --- a/candle-core/benches/benchmarks/reduce.rs +++ b/candle-core/benches/benchmarks/reduce.rs @@ -61,13 +61,21 @@ fn criterion_benchmark(c: &mut Criterion) { run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up))); run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up))); - run_reduce(c, &device, (lo, up)); - run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up))); - run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up))); + run_reduce(c, &device, (lo, up), false); + run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false); + run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false); - run_arg_reduce(c, &device, (lo, up)); - run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up))); - run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up))); + run_arg_reduce(c, &device, (lo, up), false); + run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false); + run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false); + + run_reduce(c, &device, (lo, up), true); + run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true); + run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true); + + run_arg_reduce(c, &device, (lo, up), true); + run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true); + run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true); } } @@ -89,6 +97,7 @@ fn run_softmax(c: &mut Criterion, device: &Device, ( DType::BF16 => "softmax_bf16", _ => "softmax", }; + softmax(&a).unwrap(); let mut group = c.benchmark_group(device.bench_name(name)); group.throughput(Throughput::Bytes(flops as u64)); @@ -105,19 +114,49 @@ fn run_softmax(c: &mut Criterion, device: &Device, ( group.finish(); } -fn run_reduce(c: &mut Criterion, device: &Device, (lo, up): (T, T)) { +fn run_reduce( + c: &mut Criterion, + device: &Device, + (lo, up): (T, T), + strided: bool, +) { let b = 1; let m = 1024; let k = 1024; - let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap(); + let a = if strided { + Tensor::rand(lo, up, (b, m, k), &device) + .unwrap() + .transpose(0, 2) + .unwrap() + } else { + Tensor::rand(lo, up, (b, m, k), &device).unwrap() + }; let 
flops = b * m * k * T::DTYPE.size_in_bytes(); let name = match T::DTYPE { - DType::F32 => "reduce_f32", - DType::F16 => "reduce_f16", - DType::BF16 => "reduce_bf16", + DType::F32 => { + if strided { + "reduce_f32_strided" + } else { + "reduce_f32" + } + } + DType::F16 => { + if strided { + "reduce_f16_strided" + } else { + "reduce_f16" + } + } + DType::BF16 => { + if strided { + "reduce_bf16_strided" + } else { + "reduce_bf16" + } + } _ => "reduce", }; @@ -140,20 +179,46 @@ fn run_arg_reduce( c: &mut Criterion, device: &Device, (lo, up): (T, T), + strided: bool, ) { let b = 1; let m = 1024; let k = 1024; - let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap(); + let a = if strided { + Tensor::rand(lo, up, (b, m, k), &device) + .unwrap() + .transpose(0, 2) + .unwrap() + } else { + Tensor::rand(lo, up, (b, m, k), &device).unwrap() + }; - let flops = b * m * k * T::DTYPE.size_in_bytes(); + let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes()); let name = match T::DTYPE { - DType::F32 => "arg_reduce_f32", - DType::F16 => "arg_reduce_f16", - DType::BF16 => "arg_reduce_bf16", - _ => "reduce", + DType::F32 => { + if strided { + "arg_reduce_f32_strided" + } else { + "arg_reduce_f32" + } + } + DType::F16 => { + if strided { + "arg_reduce_f16_strided" + } else { + "arg_reduce_f16" + } + } + DType::BF16 => { + if strided { + "arg_reduce_bf16_strided" + } else { + "arg_reduce_bf16" + } + } + _ => "unknown", }; let mut group = c.benchmark_group(device.bench_name(name)); diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index b494c191..06e2438e 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -2,6 +2,9 @@ #include using namespace metal; +// TODO: Load multiple values per thread to improve memory bandwidth utilization +// static constant constexpr uint VALUES_PER_THREAD = 1; + METAL_FUNC uint get_strided_index( uint idx, constant const size_t &num_dims, @@ -18,647 +21,637 @@ METAL_FUNC uint get_strided_index( return strided_i; } -#define impl_reduction_op(name, op, init_val) \ -template \ +template +struct Indexed { + uint i; + V val; + typedef V type; + + constexpr Indexed() thread = default; + constexpr Indexed() threadgroup = default; + constexpr Indexed() device = default; + constexpr Indexed() constant = default; + + constexpr Indexed(uint _i, V _val) : i(_i), val(_val) {} + + template >::type> + constexpr Indexed(uint _i, U _val) : i(_i), val(static_cast(_val)) {} + + template + constexpr Indexed(const thread Indexed &iv): Indexed(iv.i, iv.val) {} + + template + constexpr Indexed(const threadgroup Indexed &iv): Indexed(iv.i, iv.val) {} + + Indexed operator=(const thread Indexed &iv) thread { + this->i = iv.i; + this->val = iv.val; + return *this; + } + Indexed operator=(const thread Indexed &iv) threadgroup { + this->i = iv.i; + this->val = iv.val; + return *this; + } +}; + +template +constexpr METAL_FUNC bool operator<(Indexed lhs, Indexed rhs) { + return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.i < rhs.i); +} + +template +constexpr METAL_FUNC bool operator>(Indexed lhs, Indexed rhs) { + return lhs.val > rhs.val || (lhs.val == rhs.val && lhs.i > rhs.i); +} + +template +struct _numeric_limits_impl> { + static constexpr Indexed lowest() { + return Indexed(0, numeric_limits::lowest()); + } + + static constexpr Indexed max() { + return Indexed(0, numeric_limits::max()); + } +}; + +#if defined(__HAVE_BFLOAT__) +// Metal does not have simd_shuffle_down for bfloat16 +// TODO: 
Check if volatile threadgroup memory reduction is faster than simd_shuffle_down for bfloat +bfloat simd_shuffle_down(bfloat value, ushort delta) { + return static_cast(__metal_simd_shuffle_down(static_cast(value), delta)); +} +#endif + +template +Indexed simd_shuffle_down(Indexed iv, ushort delta) { + return Indexed( + simd_shuffle_down(iv.i, delta), + simd_shuffle_down(iv.val, delta) + ); +} + +#define impl_reduction_op_helper(name, op, init_val, __result_type__) \ +template \ struct name { \ - \ - static constexpr constant T init = init_val; \ - \ - METAL_FUNC R operator()(thread const T &a, thread const T &b) const { \ - return op; \ + static constexpr T init() { \ + return init_val; \ } \ - \ - METAL_FUNC R operator()(threadgroup const T &a, threadgroup const T &b) const { \ - return op; \ - } \ - \ - METAL_FUNC R operator()(device const T &a, device const T &b) const { \ - return op; \ - } \ - \ METAL_FUNC R operator()(T a, T b) { \ return op; \ } \ + METAL_FUNC R operator()(thread const T& a, thread const T& b) const { \ + return op; \ + } \ + METAL_FUNC R operator()(threadgroup const T& a, threadgroup const T& b) const { \ + return op; \ + } \ } \ +#define impl_reduction_op(name, op, init_val) \ +impl_reduction_op_helper(name, op, init_val, T); + +#define impl_arg_reduction_op(name, op, init_val) \ +impl_reduction_op_helper(name, op, init_val, tuple>); + impl_reduction_op(Sum, a + b, 0); impl_reduction_op(Mul, a * b, 1); impl_reduction_op(Min, a < b ? a : b, numeric_limits::max()); -impl_reduction_op(Max, a > b ? a : b, numeric_limits::min()); -impl_reduction_op(ArgMin, a < b, numeric_limits::max()); -impl_reduction_op(ArgMax, a > b, numeric_limits::min()); +impl_reduction_op(Max, a > b ? a : b, numeric_limits::lowest()); #undef impl_reduction_op -static constant constexpr int THREADGROUP_SIZE = 2048; +// These are used when loading elements from global memory into shared memory. +// They let us use the same code for both indexed and non-indexed types. +template +METAL_FUNC T apply_operator(Op op, size_t _idx, T a, U b) { + return op(a, static_cast(b)); +} -// Load strided elements from global memory into shared memory. -template -METAL_FUNC void load_from_global( +template +METAL_FUNC Indexed apply_operator(Op op, size_t idx, Indexed a, U b) { + return op(a, Indexed(idx, b)); +} + +// Load elements from global memory into shared memory. +// Handles both indexed and non-indexed types by using apply_operator. 
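+// Each thread walks its threadgroup's slice of the input in steps of BLOCKSIZE,
+// accumulating into a register and returning the partial result; the caller then stores
+// that partial in shared memory. For indexed reductions apply_operator wraps each element
+// as Indexed(idx, value), and for STRIDED inputs the logical index is remapped through
+// get_strided_index before the element is read.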
+template< + typename T, + typename R, + typename ReductionOp, + ushort BLOCKSIZE, + bool STRIDED = false +> +METAL_FUNC R load_from_global( + R value, + constant size_t &num_elements, constant size_t &num_dims, constant size_t *dims, constant size_t *strides, constant size_t &el_to_sum_per_block, - device const T *src, - threadgroup T shared[BLOCKSIZE], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint block_dim [[ threads_per_threadgroup ]] + const device T *src, + const ushort offset, + threadgroup R shared[BLOCKSIZE], + const ushort tid ) { ReductionOp op; - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = start_idx + el_to_sum_per_block; - size_t idx = start_idx + tid; - while (idx < stop_idx) { - size_t strided_i = get_strided_index(idx, num_dims, dims, strides); - shared[tid] = op(shared[tid], src[strided_i]); - idx += block_dim; - } - threadgroup_barrier(mem_flags::mem_threadgroup); -} + size_t stop_idx = offset + el_to_sum_per_block; + size_t idx = offset + tid; -// Load strided elements from global memory into shared memory with indices. -template -METAL_FUNC void load_from_global( - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides, - constant size_t &el_to_sum_per_block, - device const T *src, - threadgroup T shared[BLOCKSIZE], - threadgroup uint shared_indices[BLOCKSIZE], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint block_dim [[ threads_per_threadgroup ]] -) { - ArgReductionOp op; - bool notset = true; - - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = start_idx + el_to_sum_per_block; - size_t idx = start_idx + tid; while (idx < stop_idx) { - size_t strided_i = get_strided_index(idx, num_dims, dims, strides); - if (notset || op(src[strided_i], shared[tid])) { - shared[tid] = src[strided_i]; - // Assume that the reduction takes place over the last dimension which is contiguous. - shared_indices[tid] = idx % dims[num_dims - 1]; - notset = false; + if (STRIDED) { + idx = get_strided_index(idx, num_dims, dims, strides); } - idx += block_dim; + value = apply_operator(op, idx, value, src[idx]); + idx += BLOCKSIZE; } - threadgroup_barrier(mem_flags::mem_threadgroup); + return value; } -// Load contiguous elements from global memory into shared memory. -template -METAL_FUNC void load_from_global( - constant size_t &num_dims, - constant size_t *dims, + +// Convenience function for when we don't need to sum over multiple dimensions. +template< + typename T, + typename R, + typename ReductionOp, + ushort BLOCKSIZE +> +METAL_FUNC R load_from_global( + R value, + constant size_t &num_elements, constant size_t &el_to_sum_per_block, - device const T *src, + const device T *src, + const size_t offset, + threadgroup R shared[BLOCKSIZE], + const ushort tid +) { + return load_from_global( + value, + num_elements, + // Dummy values for num_dims, dims, and strides + num_elements, + nullptr, + nullptr, + // end dummy values + el_to_sum_per_block, + src, + offset, + shared, + tid + ); +} + +// Since we are using simd_shuffle_down with a BLOCKSIZE guard we don't need any barriers. 
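+// simd_shuffle_down(value, delta) reads the register of the lane `delta` positions higher
+// within the same simdgroup. Halving delta each step (16, 8, 4, 2, 1) folds the 32 lanes
+// together so that lane 0 ends up with the reduction of the whole simdgroup; the BLOCKSIZE
+// guards skip shuffles that cannot contribute when the threadgroup is smaller than that.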
+template +METAL_FUNC T simdgroup_reduce(T value) { + ReductionOp op; + if (BLOCKSIZE >= 32) value = op(value, simd_shuffle_down(value, 16)); + if (BLOCKSIZE >= 16) value = op(value, simd_shuffle_down(value, 8)); + if (BLOCKSIZE >= 8) value = op(value, simd_shuffle_down(value, 4)); + if (BLOCKSIZE >= 4) value = op(value, simd_shuffle_down(value, 2)); + if (BLOCKSIZE >= 2) value = op(value, simd_shuffle_down(value, 1)); + return value; +} + +template< + typename ReductionOp, + ushort BLOCKSIZE, + typename T +> +METAL_FUNC T threadgroup_reduce( threadgroup T shared[BLOCKSIZE], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint block_dim [[ threads_per_threadgroup ]] + ushort tid [[ thread_index_in_threadgroup ]] ) { ReductionOp op; - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = start_idx + el_to_sum_per_block; - size_t idx = start_idx + tid; - while (idx < stop_idx) { - shared[tid] = op(shared[tid], src[idx]); - idx += block_dim; - } - threadgroup_barrier(mem_flags::mem_threadgroup); -} - -// Load contiguous elements from global memory into shared memory with indices. -template -METAL_FUNC void load_from_global( - constant size_t &num_dims, - constant size_t *dims, - constant size_t &el_to_sum_per_block, - device const T *src, - threadgroup T shared[BLOCKSIZE], - threadgroup uint shared_indices[BLOCKSIZE], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint block_dim [[ threads_per_threadgroup ]] -) { - ArgReductionOp op; - bool notset = true; - - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = start_idx + el_to_sum_per_block; - size_t idx = start_idx + tid; - while (idx < stop_idx) { - if (notset || op(src[idx], shared[tid])) { - shared[tid] = src[idx]; - // Assume that the reduction takes place over the last dimension which is contiguous. - shared_indices[tid] = idx % dims[num_dims - 1]; - notset = false; + // Fully unrolled reduction loop from BLOCKSIZE down to 64. + #pragma clang loop unroll(full) + for (uint s = BLOCKSIZE / 2; s >= 64; s >>= 1) { + if (tid < s) { + shared[tid] = op(shared[tid], shared[tid + s]); } - idx += block_dim; + threadgroup_barrier(mem_flags::mem_none); } - threadgroup_barrier(mem_flags::mem_threadgroup); -} -#define reduce_threadgroup(SIZE) \ -if (BLOCKSIZE >= SIZE) { \ - if (block_dim >= SIZE) { \ - shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \ - } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ -} + if (tid < 32) { + // Last shared memory reduce can be done without tid < s check. + if (BLOCKSIZE >= 64) { + shared[tid] = op(shared[tid], shared[tid + 32]); + simdgroup_barrier(mem_flags::mem_none); + } + // Remaining 32 threads can be reduced with simdgroup_reduce. 
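+        // After this call thread 0 holds the combined value for the whole threadgroup;
+        // every thread returns shared[tid], but only thread 0's result is used by the caller.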
+ shared[tid] = simdgroup_reduce(shared[tid]); + } -template -METAL_FUNC void threadgroup_reduce( - threadgroup T shared[BLOCKSIZE], - uint tid [[thread_index_in_threadgroup]], - uint block_dim [[ threads_per_threadgroup ]] -) { - ReductionOp op; - reduce_threadgroup(64); - reduce_threadgroup(32); - reduce_threadgroup(16); - reduce_threadgroup(8); - reduce_threadgroup(4); - reduce_threadgroup(2); + return shared[tid]; } -#undef reduce_threadgroup - -#define arg_reduce_threadgroup(SIZE) \ -if (BLOCKSIZE >= SIZE) { \ - if (block_dim >= SIZE && \ - op(shared[tid], shared[tid + SIZE / 2]) \ - ) { \ - shared_indices[tid] = shared_indices[tid + SIZE / 2]; \ - shared[tid] = shared[tid + SIZE / 2]; \ - } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ -} - -template -METAL_FUNC void threadgroup_reduce( - threadgroup T shared[BLOCKSIZE], - threadgroup uint shared_indices[BLOCKSIZE], - uint tid [[thread_index_in_threadgroup]], - uint block_dim [[ threads_per_threadgroup ]] -) { - ArgReductionOp op; - arg_reduce_threadgroup(64); - arg_reduce_threadgroup(32); - arg_reduce_threadgroup(16); - arg_reduce_threadgroup(8); - arg_reduce_threadgroup(4); - arg_reduce_threadgroup(2); -} -#undef arg_reduce_threadgroup - -#define reduce_block(SIZE) \ -if (BLOCKSIZE >= SIZE) { \ - if (tid < SIZE / 2 && block_dim >= SIZE) { \ - shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \ - } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ -} \ // Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris template< typename T, + typename R, typename ReductionOp, - uint BLOCKSIZE, - bool STRIDED + ushort BLOCKSIZE, + bool STRIDED = false > -METAL_FUNC void block_reduce( +METAL_FUNC void reduce( constant size_t &num_dims, constant size_t *dims, constant size_t *strides, constant size_t &el_to_sum_per_block, device const T *src, - device T *dst, - constant uint &num_elements, + device R *dst, + constant size_t &num_elements, threadgroup T shared[BLOCKSIZE], - uint id [[ thread_position_in_grid ]], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint block_dim [[ threads_per_threadgroup ]] + ushort tid [[ thread_index_in_threadgroup ]], + ushort dst_id [[ threadgroup_position_in_grid ]] ) { - ReductionOp op; + // Initialize shared memory for current thread to correct value for reduction operation + shared[tid] = ReductionOp::init(); - shared[tid] = ReductionOp::init; + // Calcluate offset for the threadgroup of current thread + ushort offset = dst_id * el_to_sum_per_block; + R initial = ReductionOp::init(); + // Load with reduction from global memory into shared memory + shared[tid] = load_from_global( + initial, + num_elements, + num_dims, + dims, + strides, + el_to_sum_per_block, + src, + offset, + shared, + tid + ); + // Threadgroup barrier is needed to ensure that all threads have written to shared memory + // Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers. 
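+    // At this point shared holds one partial per thread; threadgroup_reduce below folds
+    // these partials into a single value for the threadgroup.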
+ threadgroup_barrier(mem_flags::mem_none); - if (STRIDED) { - load_from_global( - num_dims, - dims, - strides, - el_to_sum_per_block, - src, - shared, - tid, - dst_id, - block_dim - ); - } else { - load_from_global( - num_dims, - dims, - el_to_sum_per_block, - src, - shared, - tid, - dst_id, - block_dim - ); - } + // Complete reduction + R value = threadgroup_reduce(shared, tid); - reduce_block(1024); - reduce_block(512); - reduce_block(256); - reduce_block(128); - - if (tid < 32) { - threadgroup_reduce(shared, tid, block_dim); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (tid == 0) { - dst[dst_id] = shared[tid]; - } + if (tid == 0) dst[dst_id] = value; } -#undef reduce_block -static constant constexpr int BLOCKSIZE = 2048; -#define REDUCE(OP, NAME, T) \ -kernel void NAME( \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device T *dst, \ - constant uint &num_elements, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup T shared[BLOCKSIZE]; \ - block_reduce, BLOCKSIZE, false>( \ - num_dims, \ - dims, \ - strides, \ - el_to_sum_per_block, \ - src, \ - dst, \ - num_elements, \ - shared, \ - id, \ - tid, \ - dst_id, \ - block_dim); \ -} \ -kernel void NAME##_strided( \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device T *dst, \ - constant uint &num_elements, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup T shared[BLOCKSIZE]; \ - block_reduce, BLOCKSIZE, false>( \ - num_dims, \ - dims, \ - strides, \ - el_to_sum_per_block, \ - src, \ - dst, \ - num_elements, \ - shared, \ - id, \ - tid, \ - dst_id, \ - block_dim); \ -} \ +#define reduce_case(OP, T, R, N) \ +case N: { \ + threadgroup R shared[N]; \ + reduce, N, STRIDED>( \ + num_dims, \ + dims, \ + strides, \ + el_to_sum_per_block, \ + src, \ + dst, \ + num_elements, \ + shared, \ + tid, \ + dst_id); \ + break; \ +} -#define arg_reduce_block(SIZE) \ -if (BLOCKSIZE >= SIZE) { \ - if (tid < SIZE / 2 \ - && block_dim >= SIZE \ - && arg_op(shared[tid], shared[tid + SIZE / 2]) \ - ) { \ - shared_indices[tid] = shared_indices[tid + SIZE / 2]; \ - shared[tid] = shared[tid + SIZE / 2]; \ - } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ -} \ +#define impl_reduce(OP, NAME, T) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + constant size_t &num_elements, \ + ushort tid [[ thread_index_in_threadgroup ]], \ + ushort dst_id [[ threadgroup_position_in_grid ]], \ + ushort block_dim [[ threads_per_threadgroup ]] \ +) { \ + constant size_t *dims = {}; \ + constant size_t *strides = {}; \ + const bool STRIDED = false; \ + switch (block_dim) { \ + reduce_case(OP, T, T, 2048); \ + reduce_case(OP, T, T, 1024); \ + reduce_case(OP, T, T, 512); \ + reduce_case(OP, T, T, 256); \ + reduce_case(OP, T, T, 128); \ + reduce_case(OP, T, T, 64); \ + reduce_case(OP, T, T, 32); \ + reduce_case(OP, T, T, 16); \ + reduce_case(OP, T, T, 8); \ + reduce_case(OP, T, T, 4); \ + reduce_case(OP, T, T, 2); \ + reduce_case(OP, 
T, T, 1); \ + } \ +} \ +kernel void NAME##_strided( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + constant size_t &num_elements, \ + ushort tid [[ thread_index_in_threadgroup ]], \ + ushort dst_id [[ threadgroup_position_in_grid ]], \ + ushort block_dim [[ threads_per_threadgroup ]] \ +) { \ + const bool STRIDED = true; \ + switch (block_dim) { \ + reduce_case(OP, T, T, 2048); \ + reduce_case(OP, T, T, 1024); \ + reduce_case(OP, T, T, 512); \ + reduce_case(OP, T, T, 256); \ + reduce_case(OP, T, T, 128); \ + reduce_case(OP, T, T, 64); \ + reduce_case(OP, T, T, 32); \ + reduce_case(OP, T, T, 16); \ + reduce_case(OP, T, T, 8); \ + reduce_case(OP, T, T, 4); \ + reduce_case(OP, T, T, 2); \ + reduce_case(OP, T, T, 1); \ + } \ +} template< typename T, - typename ArgReductionOp, - uint BLOCKSIZE, + typename ReductionOp, + ushort BLOCKSIZE, bool STRIDED > -METAL_FUNC void arg_block_reduce( +METAL_FUNC void reduce( constant size_t &num_dims, constant size_t *dims, constant size_t *strides, constant size_t &el_to_sum_per_block, device const T *src, device uint *dst, - threadgroup T shared[BLOCKSIZE], - threadgroup uint shared_indices[BLOCKSIZE], - uint id [[ thread_position_in_grid ]], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint block_dim [[ threads_per_threadgroup ]] + constant size_t &num_elements, + threadgroup Indexed shared[BLOCKSIZE], + ushort tid [[ thread_index_in_threadgroup ]], + ushort dst_id [[ threadgroup_position_in_grid ]] ) { - ArgReductionOp arg_op; + // Initialize shared memory for current thread to correct value for reduction operation + shared[tid] = ReductionOp::init(); - shared[tid] = ArgReductionOp::init; - shared_indices[tid] = numeric_limits::max(); + // Calcluate offset for the threadgroup of current thread + ushort offset = dst_id * el_to_sum_per_block; + Indexed initial = ReductionOp::init(); + // Load with reduction from global memory into shared memory + shared[tid] = load_from_global, ReductionOp, BLOCKSIZE, STRIDED>( + initial, + num_elements, + num_dims, + dims, + strides, + el_to_sum_per_block, + src, + offset, + shared, + tid + ); + // Threadgroup barrier is needed to ensure that all threads have written to shared memory + // Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers. 
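+    // shared now holds Indexed partials; when two values tie, the Indexed comparison
+    // operators fall back to the stored index, which keeps the selected index deterministic.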
+ threadgroup_barrier(mem_flags::mem_none); - if (STRIDED) { - load_from_global( - num_dims, - dims, - strides, - el_to_sum_per_block, - src, - shared, - shared_indices, - tid, - dst_id, - block_dim - ); - } else { - load_from_global( - num_dims, - dims, - el_to_sum_per_block, - src, - shared, - shared_indices, - tid, - dst_id, - block_dim - ); - } - arg_reduce_block(1024); - arg_reduce_block(512); - arg_reduce_block(256); - arg_reduce_block(128); + // Complete reduction + Indexed value = threadgroup_reduce>(shared, tid); - if (tid < 32) { - threadgroup_reduce(shared, shared_indices, tid, block_dim); - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - if (tid == 0) { - dst[dst_id] = shared_indices[0]; - } -} -#undef arg_reduce_block - -#define ARG_REDUCE(OP, NAME, T) \ -kernel void NAME( \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device uint *dst, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup T shared[BLOCKSIZE]; \ - threadgroup uint shared_indices[BLOCKSIZE]; \ - arg_block_reduce, BLOCKSIZE, false>( \ - num_dims, \ - dims, \ - strides, \ - el_to_sum_per_block, \ - src, \ - dst, \ - shared, \ - shared_indices, \ - id, \ - tid, \ - dst_id, \ - block_dim); \ -} \ -kernel void NAME##_strided( \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device uint *dst, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup T shared[BLOCKSIZE]; \ - threadgroup uint shared_indices[BLOCKSIZE]; \ - arg_block_reduce, BLOCKSIZE, true>( \ - num_dims, \ - dims, \ - strides, \ - el_to_sum_per_block, \ - src, \ - dst, \ - shared, \ - shared_indices, \ - id, \ - tid, \ - dst_id, \ - block_dim); \ + // Return index of reduce result + if (tid == 0) dst[dst_id] = value.i; } - -#define MAX(x, y) ((x) > (y) ? (x) : (y)) -#define MIN(x, y) ((x) < (y) ? 
(x) : (y)) - - -#define softmax_max_block(SIZE) \ -if (BLOCKSIZE >= SIZE) { \ - if (tid < SIZE / 2 && block_dim >= SIZE) { \ - shared[tid] = max_op(shared[tid], shared[tid + SIZE / 2]); \ - } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ +#define arg_reduce_case(OP, T, N) \ +case N: { \ + threadgroup Indexed shared[N]; \ + reduce>, N, STRIDED>( \ + num_dims, \ + dims, \ + strides, \ + el_to_sum_per_block, \ + src, \ + dst, \ + num_elements, \ + shared, \ + tid, \ + dst_id); \ + break; \ } -#define softmax_acc_block(SIZE) \ -if (BLOCKSIZE >= SIZE) { \ - if (tid < SIZE / 2 && block_dim >= SIZE) { \ - shared[tid] += shared[tid + SIZE / 2]; \ - } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ +#define impl_arg_reduce(OP, NAME, T) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + constant size_t &num_elements, \ + ushort tid [[ thread_index_in_threadgroup ]], \ + ushort dst_id [[ threadgroup_position_in_grid ]], \ + ushort block_dim [[ threads_per_threadgroup ]] \ +) { \ + constant size_t *dims = {}; \ + constant size_t *strides = {}; \ + const bool STRIDED = false; \ + switch (block_dim) { \ + arg_reduce_case(OP, T, 2048); \ + arg_reduce_case(OP, T, 1024); \ + arg_reduce_case(OP, T, 512); \ + arg_reduce_case(OP, T, 256); \ + arg_reduce_case(OP, T, 128); \ + arg_reduce_case(OP, T, 64); \ + arg_reduce_case(OP, T, 32); \ + arg_reduce_case(OP, T, 16); \ + arg_reduce_case(OP, T, 8); \ + arg_reduce_case(OP, T, 4); \ + arg_reduce_case(OP, T, 2); \ + arg_reduce_case(OP, T, 1); \ + } \ +} \ +kernel void NAME##_strided( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + constant size_t &num_elements, \ + ushort tid [[ thread_index_in_threadgroup ]], \ + ushort dst_id [[ threadgroup_position_in_grid ]], \ + ushort block_dim [[ threads_per_threadgroup ]] \ +) { \ + const bool STRIDED = true; \ + switch (block_dim) { \ + arg_reduce_case(OP, T, 2048); \ + arg_reduce_case(OP, T, 1024); \ + arg_reduce_case(OP, T, 512); \ + arg_reduce_case(OP, T, 256); \ + arg_reduce_case(OP, T, 128); \ + arg_reduce_case(OP, T, 64); \ + arg_reduce_case(OP, T, 32); \ + arg_reduce_case(OP, T, 16); \ + arg_reduce_case(OP, T, 8); \ + arg_reduce_case(OP, T, 4); \ + arg_reduce_case(OP, T, 2); \ + arg_reduce_case(OP, T, 1); \ + } \ } template< typename T, - typename ACC, - uint BLOCKSIZE + typename ACC = float, + ushort BLOCKSIZE > METAL_FUNC void softmax( constant size_t &src_numel, constant size_t &el_to_sum_per_block, - device const T *src, + const device T *src, device T *dst, threadgroup ACC shared[BLOCKSIZE], - uint id [[ thread_position_in_grid ]], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint block_dim [[ threads_per_threadgroup ]] + ushort tid [[ thread_index_in_threadgroup ]], + ushort dst_id [[ threadgroup_position_in_grid ]] ) { - Max max_op; + // Initialize shared memory for current thread to lowest value + shared[tid] = numeric_limits::lowest(); - shared[tid] = numeric_limits::min(); - ACC tmp = numeric_limits::min(); + // Calcluate offset for the threadgroup of current thread + size_t offset = dst_id * el_to_sum_per_block; + ACC initial = numeric_limits::lowest(); + // Load with reduction from global memory into shared memory + shared[tid] = load_from_global, BLOCKSIZE>( + initial, + src_numel, + el_to_sum_per_block, + src, + 
offset, + shared, + tid + ); + // Threadgroup barrier is needed to ensure that all threads have written to shared memory + // Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers. + threadgroup_barrier(mem_flags::mem_none); - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); - size_t idx = start_idx + tid; + // Reduce shared memory to find max value + threadgroup_reduce, BLOCKSIZE>(shared, tid); + ACC max_result = shared[0]; - while (idx < stop_idx) { - tmp = max_op(tmp, static_cast(src[idx])); - idx += block_dim; - } - shared[tid] = tmp; - threadgroup_barrier(mem_flags::mem_threadgroup); - - softmax_max_block(1024); - softmax_max_block(512); - softmax_max_block(256); - softmax_max_block(128); - if (tid < 32) { - threadgroup_reduce, BLOCKSIZE>(shared, tid, block_dim); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - ACC _max = shared[0]; - - // prevent tid 0 from overwriting _max before other threads have written - threadgroup_barrier(mem_flags::mem_threadgroup); + // Ensure all threads have max_result = shared[0] before we set shared[0] = 0. + threadgroup_barrier(mem_flags::mem_none); shared[tid] = 0; - idx = start_idx + tid; + // Calculate softmax values + size_t stop_idx = min(offset + el_to_sum_per_block, src_numel); + size_t idx = offset + tid; while (idx < stop_idx) { - const ACC val = exp(static_cast(src[idx]) - _max); - dst[idx] = static_cast(val); + const ACC val = exp(ACC(src[idx]) - max_result); + dst[idx] = T(val); shared[tid] += val; - - idx += block_dim; + idx += BLOCKSIZE; } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_none); - softmax_acc_block(1024); - softmax_acc_block(512); - softmax_acc_block(256); - softmax_acc_block(128); - if (tid < 32) { - threadgroup_reduce, BLOCKSIZE>(shared, tid, block_dim); - threadgroup_barrier(mem_flags::mem_none); - } + threadgroup_reduce, BLOCKSIZE>(shared, tid); + threadgroup_barrier(mem_flags::mem_none); const T inv_acc = T(1.0/shared[0]); - idx = start_idx + tid; + idx = offset + tid; while (idx < stop_idx) { dst[idx] *= inv_acc; - idx += block_dim; + idx += BLOCKSIZE; } } - -#define SOFTMAX(NAME, T, ACC) \ -kernel void NAME( \ - constant size_t &src_numel, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device T *dst, \ - \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup ACC shared_memory[BLOCKSIZE]; \ - softmax( \ - src_numel, \ - el_to_sum_per_block, \ - src, \ - dst, \ - shared_memory, \ - id, \ - tid, \ - dst_id, \ - block_dim); \ +#define softmax_case(T, ACC, N) \ +case N: { \ + threadgroup ACC shared[N]; \ + softmax( \ + src_numel, \ + el_to_sum_per_block, \ + src, \ + dst, \ + shared, \ + tid, \ + dst_id); \ + break; \ } -REDUCE(Sum, fast_sum_f32, float) -REDUCE(Sum, fast_sum_u32, uint) -REDUCE(Sum, fast_sum_f16, half) -REDUCE(Sum, fast_sum_u8, uint8_t) +#define impl_softmax(NAME, T, ACC) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + \ + ushort tid [[ thread_index_in_threadgroup ]], \ + ushort dst_id [[ threadgroup_position_in_grid ]], \ + ushort block_dim [[ threads_per_threadgroup ]] \ +) { \ + switch (block_dim) { \ + softmax_case(T, ACC, 2048); \ + softmax_case(T, ACC, 1024); 
\ + softmax_case(T, ACC, 512); \ + softmax_case(T, ACC, 256); \ + softmax_case(T, ACC, 128); \ + softmax_case(T, ACC, 64); \ + softmax_case(T, ACC, 32); \ + softmax_case(T, ACC, 16); \ + softmax_case(T, ACC, 8); \ + softmax_case(T, ACC, 4); \ + softmax_case(T, ACC, 2); \ + softmax_case(T, ACC, 1); \ + } \ +} -REDUCE(Mul, fast_mul_f32, float) -REDUCE(Mul, fast_mul_u32, uint) -REDUCE(Mul, fast_mul_f16, half) -REDUCE(Mul, fast_mul_u8, uint8_t) +impl_reduce(Sum, fast_sum_f32, float) +impl_reduce(Sum, fast_sum_u32, uint) +impl_reduce(Sum, fast_sum_f16, half) +impl_reduce(Sum, fast_sum_u8, uint8_t) -REDUCE(Max, fast_max_f32, float) -REDUCE(Max, fast_max_u32, uint) -REDUCE(Max, fast_max_f16, half) -REDUCE(Max, fast_max_u8, uint8_t) +impl_reduce(Mul, fast_mul_f32, float) +impl_reduce(Mul, fast_mul_u32, uint) +impl_reduce(Mul, fast_mul_f16, half) +impl_reduce(Mul, fast_mul_u8, uint8_t) -REDUCE(Min, fast_min_f32, float) -REDUCE(Min, fast_min_u32, uint) -REDUCE(Min, fast_min_f16, half) -REDUCE(Min, fast_min_u8, uint8_t) +impl_reduce(Max, fast_max_f32, float) +impl_reduce(Max, fast_max_u32, uint) +impl_reduce(Max, fast_max_f16, half) +impl_reduce(Max, fast_max_u8, uint8_t) -ARG_REDUCE(ArgMin, fast_argmin_f32, float) -ARG_REDUCE(ArgMin, fast_argmin_f16, half) -ARG_REDUCE(ArgMin, fast_argmin_u32, uint) -ARG_REDUCE(ArgMin, fast_argmin_u8, uint8_t) +impl_reduce(Min, fast_min_f32, float) +impl_reduce(Min, fast_min_u32, uint) +impl_reduce(Min, fast_min_f16, half) +impl_reduce(Min, fast_min_u8, uint8_t) -ARG_REDUCE(ArgMax, fast_argmax_f32, float) -ARG_REDUCE(ArgMax, fast_argmax_f16, half) -ARG_REDUCE(ArgMax, fast_argmax_u32, uint) -ARG_REDUCE(ArgMax, fast_argmax_u8, uint8_t) +impl_arg_reduce(Min, fast_argmin_f32, float) +impl_arg_reduce(Min, fast_argmin_f16, half) +impl_arg_reduce(Min, fast_argmin_u32, uint) +impl_arg_reduce(Min, fast_argmin_u8, uint8_t) -SOFTMAX(softmax_f32, float, float) -SOFTMAX(softmax_f16, half, float) +impl_arg_reduce(Max, fast_argmax_f32, float) +impl_arg_reduce(Max, fast_argmax_f16, half) +impl_arg_reduce(Max, fast_argmax_u32, uint) +impl_arg_reduce(Max, fast_argmax_u8, uint8_t) + +impl_softmax(softmax_f32, float, float) +impl_softmax(softmax_f16, half, float) #if __METAL_VERSION__ >= 220 -REDUCE(Sum, fast_sum_i64, int64_t) -REDUCE(Mul, fast_mul_i64, int64_t) -REDUCE(Min, fast_min_i64, int64_t) -REDUCE(Max, fast_max_i64, int64_t) +impl_reduce(Sum, fast_sum_i64, int64_t) +impl_reduce(Mul, fast_mul_i64, int64_t) +impl_reduce(Min, fast_min_i64, int64_t) +impl_reduce(Max, fast_max_i64, int64_t) -ARG_REDUCE(ArgMin, fast_argmin_i64, int64_t) -ARG_REDUCE(ArgMax, fast_argmax_i64, int64_t) +impl_arg_reduce(Min, fast_argmin_i64, int64_t) +impl_arg_reduce(Max, fast_argmax_i64, int64_t) #endif #if defined(__HAVE_BFLOAT__) -REDUCE(Sum, fast_sum_bf16, bfloat) -REDUCE(Mul, fast_mul_bf16, bfloat) -REDUCE(Max, fast_max_bf16, bfloat) -REDUCE(Min, fast_min_bf16, bfloat) +impl_reduce(Sum, fast_sum_bf16, bfloat) +impl_reduce(Mul, fast_mul_bf16, bfloat) +impl_reduce(Max, fast_max_bf16, bfloat) +impl_reduce(Min, fast_min_bf16, bfloat) -ARG_REDUCE(ArgMin, fast_argmin_bf16, bfloat) -ARG_REDUCE(ArgMax, fast_argmax_bf16, bfloat) +impl_arg_reduce(Min, fast_argmin_bf16, bfloat) +impl_arg_reduce(Max, fast_argmax_bf16, bfloat) -SOFTMAX(softmax_bf16, bfloat, float) +impl_softmax(softmax_bf16, bfloat, float) #endif diff --git a/candle-metal-kernels/src/reduce_old.metal b/candle-metal-kernels/src/reduce_old.metal new file mode 100644 index 00000000..1e5c2895 --- /dev/null +++ 
b/candle-metal-kernels/src/reduce_old.metal @@ -0,0 +1,346 @@ +#include +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +constant int THREADGROUP_SIZE = 2048; + + +#define ARGMIN(NAME, T, MAXVALUE) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + threadgroup uint shared_indices[THREADGROUP_SIZE]; \ + \ + shared_memory[tid] = MAXVALUE; \ + shared_indices[tid] = 0xFFFFFFFF; \ + bool notset = true; \ + /* \ + // Elements summed in this block range from dst_id * el_to_sum_per_block \ + // to (dst_id + 1) * el_to_sum_per_block. \ + */ \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = start_idx + el_to_sum_per_block; \ + size_t idx = start_idx + tid; \ + while (idx < stop_idx) { \ + /* \ + // TODO: Fast version for the contiguous case. \ + */ \ + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + if (notset || src[strided_i] < shared_memory[tid]) { \ + shared_memory[tid] = src[strided_i]; \ + /* Assume that the reduction takes place over the last dimension which is contiguous. */ \ + shared_indices[tid] = idx % dims[num_dims - 1]; \ + notset = false; \ + } \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_none); \ + \ + /* \ + // reduction in shared memory \ + */ \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \ + shared_indices[tid] = shared_indices[tid + s]; \ + shared_memory[tid] = shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ + \ + if (tid == 0){ \ + dst[dst_id] = shared_indices[0]; \ + } \ +} \ + + +#define ARGMAX(NAME, T, MINVALUE) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + threadgroup uint shared_indices[THREADGROUP_SIZE]; \ + \ + shared_memory[tid] = MINVALUE; \ + shared_indices[tid] = 0xFFFFFFFF; \ + /* \ + // Elements summed in this block range from dst_id * el_to_sum_per_block \ + // to (dst_id + 1) * el_to_sum_per_block. \ + */ \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = start_idx + el_to_sum_per_block; \ + size_t idx = start_idx + tid; \ + bool notset = true; \ + while (idx < stop_idx) { \ + /* \ + // TODO: Fast version for the contiguous case. 
\ + */ \ + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + if (notset || shared_memory[tid] < src[strided_i]) { \ + shared_memory[tid] = src[strided_i]; \ + shared_indices[tid] = idx % dims[num_dims - 1]; \ + notset = false; \ + } \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_none); \ + \ + /* \ + // reduction in shared memory \ + */ \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \ + shared_indices[tid] = shared_indices[tid + s]; \ + shared_memory[tid] = shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ + \ + if (tid == 0){ \ + dst[dst_id] = shared_indices[0]; \ + } \ +} \ + +#define REDUCE(FN, NAME, T, START) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + \ + shared_memory[tid] = START; \ + /* \ + // Elements summed in this block range from dst_id * el_to_sum_per_block \ + // to (dst_id + 1) * el_to_sum_per_block. \ + */ \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = start_idx + el_to_sum_per_block; \ + size_t idx = start_idx + tid; \ + while (idx < stop_idx) { \ + /* \ + // TODO: Fast version for the contiguous case. \ + */ \ + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + T x = shared_memory[tid]; \ + T y = src[strided_i]; \ + shared_memory[tid] = FN; \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_none); \ + \ + /* \ + // reduction in shared memory \ + */ \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + T x = shared_memory[tid]; \ + T y = shared_memory[tid + s]; \ + shared_memory[tid] = FN; \ + } \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ + \ + dst[dst_id] = shared_memory[0]; \ +} \ + + +#define SOFTMAX(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = -INFINITY; \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ + size_t idx = start_idx + tid; \ + \ + \ + float tmp = -INFINITY; \ + while (idx < stop_idx) { \ + tmp = MAX(tmp, float(src[idx])); \ + idx += block_dim; \ + } \ + shared_memory[tid] = tmp; \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \ + } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + } \ + \ + /* wait for shared_memory[0] to be filled */ \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + float _max = shared_memory[0]; \ + \ + /* prevent tid=0 from overwriting _max before other threads have written it */ \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + shared_memory[tid] = 0; \ + \ + 
idx = start_idx + tid; \ + while (idx < stop_idx) { \ + const float val = exp(float(src[idx]) - _max); \ + dst[idx] = T(val); \ + shared_memory[tid] += val; \ + idx += block_dim; \ + } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] += shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + } \ + \ + const T inv_acc = T(1.0/shared_memory[0]); \ + idx = start_idx + tid; \ + while (idx < stop_idx) { \ + dst[idx] *= inv_acc; \ + idx += block_dim; \ + } \ +} \ + +REDUCE(x + y, fast_sum_f32_strided, float, 0) +REDUCE(x + y, fast_sum_u32_strided, uint, 0) +REDUCE(x + y, fast_sum_f16_strided, half, 0) +REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0) +REDUCE(x * y, fast_mul_f32_strided, float, 1) +REDUCE(x * y, fast_mul_u32_strided, uint, 1) +REDUCE(x * y, fast_mul_f16_strided, half, 1) +REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF) +REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0) +REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH) +REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0) +REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF) +REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF) +REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH) +REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF) +ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF) +ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH) +ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF) +ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF) +ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF) +ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH) +ARGMAX(fast_argmax_u32_strided, uint, 0) +ARGMAX(fast_argmax_u8_strided, uint8_t, 0) + + +REDUCE(x + y, fast_sum_f32, float, 0) +REDUCE(x + y, fast_sum_u32, uint, 0) +REDUCE(x + y, fast_sum_f16, half, 0) +REDUCE(x + y, fast_sum_u8, uint8_t, 0) +REDUCE(x * y, fast_mul_f32, float, 1) +REDUCE(x * y, fast_mul_u32, uint, 1) +REDUCE(x * y, fast_mul_f16, half, 1) +REDUCE(MAX(x, y), fast_max_f32, float, -HUGE_VALF) +REDUCE(MAX(x, y), fast_max_u32, uint, 0) +REDUCE(MAX(x, y), fast_max_f16, half, -HUGE_VALH) +REDUCE(MAX(x, y), fast_max_u8, uint8_t, 0) +REDUCE(MIN(x, y), fast_min_f32, float, HUGE_VALF) +REDUCE(MIN(x, y), fast_min_u32, uint, 0xFFFFFFFF) +REDUCE(MIN(x, y), fast_min_f16, half, HUGE_VALH) +REDUCE(MIN(x, y), fast_min_u8, uint8_t, 0xFF) +ARGMIN(fast_argmin_f32, float, HUGE_VALF) +ARGMIN(fast_argmin_f16, half, HUGE_VALH) +ARGMIN(fast_argmin_u32, uint, 0xFFFFFFFF) +ARGMIN(fast_argmin_u8, uint8_t, 0xFF) +ARGMAX(fast_argmax_f32, float, -HUGE_VALF) +ARGMAX(fast_argmax_f16, half, -HUGE_VALH) +ARGMAX(fast_argmax_u32, uint, 0) +ARGMAX(fast_argmax_u8, uint8_t, 0) + +SOFTMAX(softmax_f32, float) +SOFTMAX(softmax_f16, half) + +#if __METAL_VERSION__ >= 220 +REDUCE(x + y, fast_sum_i64_strided, int64_t, 0) +REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN) +ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) +ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) + + +REDUCE(x + y, fast_sum_i64, int64_t, 0) +REDUCE(MIN(x, y), fast_min_i64, int64_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i64, int64_t, INT_MIN) +ARGMIN(fast_argmin_i64, int64_t, INT_MAX) +ARGMAX(fast_argmax_i64, int64_t, INT_MIN) +#endif + +#if defined(__HAVE_BFLOAT__) +REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0) +REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1) +REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF) +REDUCE(MIN(x, 
y), fast_min_bf16_strided, bfloat, HUGE_VALBF) +ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF) +ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF) + +REDUCE(x + y, fast_sum_bf16, bfloat, 0) +REDUCE(x * y, fast_mul_bf16, bfloat, 1) +REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) +REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF) +ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) +ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) + +SOFTMAX(softmax_bf16, bfloat) +#endif \ No newline at end of file diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 655161e5..8cbf3e04 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -622,7 +622,7 @@ fn cos_f16() { assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); } -fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { +fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); @@ -630,10 +630,10 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec()) as u64, options); + let output = device.new_buffer((out_length * core::mem::size_of::()) as u64, options); let dims = vec![v.len()]; let strides = vec![1]; - call_reduce_strided( + match call_reduce_strided( &device, command_buffer, &kernels, @@ -644,8 +644,13 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec {} + Err(e) => { + println!("Error: {}", e); + panic!(); + } + } command_buffer.commit(); command_buffer.wait_until_completed(); @@ -677,22 +682,114 @@ fn run_softmax(v: &[T], last_dim: usize, name: &'sta read_to_vec(&output, v.len()) } -#[test] -fn reduce_sum() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let out_length = 1; +const fn create_array() -> [f32; N] { + let mut array: [f32; N] = [0.0; N]; + let mut i = 1; + while i <= N { + array[i - 1] = i as f32; + i += 1; + } + array +} - let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); - assert_eq!(approx(results, 4), vec![21.0]); +const fn correct_sum() -> [f32; D] { + let mut sum = 0; + let mut results: [f32; D] = [0.0; D]; + let mut i = 1; + let mut j = 1; + while i <= N { + sum += i; + i += 1; + if i > j * N / D { + results[j - 1] = sum as f32; + j += 1; + sum = 0; + } + } + results +} + +fn correct_argmax(arr: [f32; N]) -> [u32; D] { + let mut max = 0.0; + let mut max_index: u32 = 0; + let mut results: [u32; D] = [0; D]; + let mut i = 0; + let mut j = 1; + while i <= N { + if i >= (j * N / D) { + results[j - 1] = max_index; + max = 0.0; + max_index = 0; + j += 1; + } + if i == N { + break; + } + if arr[i] > max { + max = arr[i]; + max_index = i as u32; + } + i += 1; + } + results +} + +fn reduce_sum_case() { + let v = create_array::(); + let results = run_reduce(&v, D, "fast_sum_f32_strided"); + assert_eq!(approx(results, 4), correct_sum::()); +} + +fn reduce_argmax_case() { + let v = create_array::(); + let results: Vec = run_reduce(&v, D, "fast_argmax_f32_strided"); + assert_eq!(results, correct_argmax::(v)); } #[test] -fn reduce_sum2() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let out_length = 2; +fn reduce_sum() { + reduce_sum_case::<6, 1>(); + reduce_sum_case::<10, 1>(); + reduce_sum_case::<64, 1>(); + reduce_sum_case::<128, 1>(); + reduce_sum_case::<256, 1>(); + reduce_sum_case::<512, 1>(); + reduce_sum_case::<1024, 1>(); + reduce_sum_case::<2048, 1>(); + reduce_sum_case::<4096, 1>(); - let results = run_reduce(&v, out_length, 
"fast_sum_f32_strided"); - assert_eq!(approx(results, 4), vec![6.0, 15.0]); + reduce_sum_case::<6, 2>(); + reduce_sum_case::<10, 2>(); + reduce_sum_case::<64, 2>(); + reduce_sum_case::<128, 2>(); + reduce_sum_case::<256, 2>(); + reduce_sum_case::<512, 2>(); + reduce_sum_case::<1024, 2>(); + reduce_sum_case::<2048, 2>(); + reduce_sum_case::<4096, 2>(); +} + +#[test] +fn reduce_argmax() { + reduce_argmax_case::<6, 1>(); + reduce_argmax_case::<10, 1>(); + reduce_argmax_case::<64, 1>(); + reduce_argmax_case::<128, 1>(); + reduce_argmax_case::<256, 1>(); + reduce_argmax_case::<512, 1>(); + reduce_argmax_case::<1024, 1>(); + reduce_argmax_case::<2048, 1>(); + reduce_argmax_case::<4096, 1>(); + + reduce_argmax_case::<6, 2>(); + reduce_argmax_case::<10, 2>(); + reduce_argmax_case::<64, 2>(); + reduce_argmax_case::<128, 2>(); + reduce_argmax_case::<256, 2>(); + reduce_argmax_case::<512, 2>(); + reduce_argmax_case::<1024, 2>(); + reduce_argmax_case::<2048, 2>(); + reduce_argmax_case::<4096, 2>(); } #[test]