Improve arg reduce and add contiguous impl

Ivar Flakstad
2024-01-21 18:12:49 +01:00
parent d5902840e0
commit 1f4c54493e
4 changed files with 358 additions and 234 deletions


@@ -1,19 +1,25 @@
use candle_core::{DType, Tensor};
use crate::benchmarks::{bench_name, device, BenchDevice};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
use crate::benchmarks::{bench_name, device, BenchDevice};
fn run(a: &Tensor) {
fn run_sum(a: &Tensor) {
a.sum(2).unwrap();
}
fn run_arg_min(a: &Tensor) {
a.argmin(2).unwrap();
}
fn criterion_benchmark(c: &mut Criterion) {
let device = device().unwrap();
run_reduce(c, &device);
run_arg_reduce(c, &device);
}
fn run_reduce(c: &mut Criterion, device: &Device) {
let b = 1;
let m = 2048;
let k = 2048;
let device = device().unwrap();
let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
let flops = b * m * k * DType::F32.size_in_bytes();
@@ -24,7 +30,31 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&a));
run_sum(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn run_arg_reduce(c: &mut Criterion, device: &Device) {
let b = 1;
let m = 2048;
let k = 2048;
let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
let flops = b * m * k * DType::F32.size_in_bytes();
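// Despite the name, `flops` is the number of bytes read per iteration; it drives the byte-throughput setting below.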
let mut group = c.benchmark_group(bench_name("arg_reduce"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_arg_min(black_box(&a));
}
device.sync().unwrap();
start.elapsed()


@@ -511,59 +511,56 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
//(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
//(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
//(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
//(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
//(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
//(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
//(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
//(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
//(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
//(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
//(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
//(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
//(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
_ => ("fall back to strided impl", false, false)
};
if name != "fall back to strided impl" {
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
(k, dtype) => {
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
}
let buffer = device.new_buffer(1, self.dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.shape().elem_count(),
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, self.dtype));
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let buffer = device.new_buffer(1, self.dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.shape().elem_count(),
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, self.dtype));
}
for &dim_idx in sum_dims.iter() {
@@ -602,7 +599,7 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?


@@ -19,24 +19,24 @@ METAL_FUNC uint get_strided_index(
}
#define impl_reduction_op(name, op, init_val) \
template<typename T> \
template<typename T, typename R = T> \
struct name { \
\
static constexpr constant T init = init_val; \
\
METAL_FUNC T operator()(thread const T &a, thread const T &b) const { \
METAL_FUNC R operator()(thread const T &a, thread const T &b) const { \
return op; \
} \
\
METAL_FUNC T operator()(threadgroup const T &a, threadgroup const T &b) const { \
METAL_FUNC R operator()(threadgroup const T &a, threadgroup const T &b) const { \
return op; \
} \
\
METAL_FUNC T operator()(device const T &a, device const T &b) const { \
METAL_FUNC R operator()(device const T &a, device const T &b) const { \
return op; \
} \
\
METAL_FUNC T operator()(T a, T b) { \
METAL_FUNC R operator()(T a, T b) { \
return op; \
} \
} \
@@ -45,10 +45,13 @@ impl_reduction_op(Sum, a + b, 0);
impl_reduction_op(Mul, a * b, 1);
impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::min());
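// Arg reductions reuse the macro with R = bool: the functor compares a candidate (first argument) against the current best (second argument) and returns true when the candidate should win; the winning index is tracked separately in shared_indices.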
impl_reduction_op(ArgMin, a < b, numeric_limits<T>::max());
impl_reduction_op(ArgMax, a > b, numeric_limits<T>::min());
#undef impl_reduction_op
static constant constexpr int THREADGROUP_SIZE = 2048;
// Load strided elements from global memory into shared memory.
template<typename T, typename ReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
@@ -74,6 +77,40 @@ METAL_FUNC void load_from_global(
threadgroup_barrier(mem_flags::mem_none);
}
// Load strided elements from global memory into shared memory with indices.
template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
threadgroup T shared[BLOCKSIZE],
threadgroup uint shared_indices[BLOCKSIZE],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ArgReductionOp op;
bool notset = true;
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || op(src[strided_i], shared[tid])) {
shared[tid] = src[strided_i];
// Assume that the reduction takes place over the last dimension which is contiguous.
shared_indices[tid] = idx % dims[num_dims - 1];
notset = false;
}
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
}
// Load contiguous elements from global memory into shared memory.
template<typename T, typename ReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
@@ -97,6 +134,45 @@ METAL_FUNC void load_from_global(
threadgroup_barrier(mem_flags::mem_none);
}
// Load contiguous elements from global memory into shared memory with indices.
template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
constant size_t *dims,
constant size_t &el_to_sum_per_block,
device const T *src,
threadgroup T shared[BLOCKSIZE],
threadgroup uint shared_indices[BLOCKSIZE],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ArgReductionOp op;
bool notset = true;
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
if (notset || op(src[idx], shared[tid])) {
shared[tid] = src[idx];
// Assume that the reduction takes place over the last dimension which is contiguous.
shared_indices[tid] = idx % dims[num_dims - 1];
notset = false;
}
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
}
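// Final in-threadgroup reduction of the lowest 64 shared-memory slots down to shared[0]; invoked with tid < 32.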
#define reduce_threadgroup(SIZE) \
if (BLOCKSIZE >= SIZE) { \
if (block_dim >= SIZE) { \
shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
threadgroup_barrier(mem_flags::mem_none); \
} \
}
template<typename T, typename ReductionOp, uint BLOCKSIZE>
METAL_FUNC void threadgroup_reduce(
threadgroup T shared[BLOCKSIZE],
@@ -104,37 +180,50 @@ METAL_FUNC void threadgroup_reduce(
uint block_dim [[ threads_per_threadgroup ]]
) {
ReductionOp op;
if (BLOCKSIZE >= 64) {
if (block_dim >= 64) {
shared[tid] = op(shared[tid], shared[tid + 32]);
}
}
if (BLOCKSIZE >= 32) {
if (block_dim >= 32) {
shared[tid] = op(shared[tid], shared[tid + 16]);
}
}
if (BLOCKSIZE >= 16) {
if (block_dim >= 16) {
shared[tid] = op(shared[tid], shared[tid + 8]);
}
}
if (BLOCKSIZE >= 8) {
if (block_dim >= 8) {
shared[tid] = op(shared[tid], shared[tid + 4]);
}
}
if (BLOCKSIZE >= 4) {
if (block_dim >= 4) {
shared[tid] = op(shared[tid], shared[tid + 2]);
}
}
if (BLOCKSIZE >= 2) {
if (block_dim >= 2) {
shared[tid] = op(shared[tid], shared[tid + 1]);
}
}
reduce_threadgroup(64);
reduce_threadgroup(32);
reduce_threadgroup(16);
reduce_threadgroup(8);
reduce_threadgroup(4);
reduce_threadgroup(2);
}
#undef reduce_threadgroup
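// Same tail reduction, but the winning value carries its index along in shared_indices.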
#define arg_reduce_threadgroup(SIZE) \
if (BLOCKSIZE >= SIZE) { \
if (block_dim >= SIZE && \
op(shared[tid + SIZE / 2], shared[tid]) \
) { \
shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
shared[tid] = shared[tid + SIZE / 2]; \
threadgroup_barrier(mem_flags::mem_none); \
} \
}
template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
METAL_FUNC void threadgroup_reduce(
threadgroup T shared[BLOCKSIZE],
threadgroup uint shared_indices[BLOCKSIZE],
uint tid [[thread_index_in_threadgroup]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ArgReductionOp op;
arg_reduce_threadgroup(64);
arg_reduce_threadgroup(32);
arg_reduce_threadgroup(16);
arg_reduce_threadgroup(8);
arg_reduce_threadgroup(4);
arg_reduce_threadgroup(2);
}
#undef arg_reduce_threadgroup
#define reduce_block(SIZE) \
if (BLOCKSIZE >= SIZE) { \
if (tid < SIZE / 2 && block_dim >= SIZE) { \
shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
threadgroup_barrier(mem_flags::mem_none); \
} \
} \
// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
template<
@@ -186,42 +275,20 @@ METAL_FUNC void block_reduce(
);
}
if (BLOCKSIZE >= 1024) {
if (tid < 512 && block_dim >= 1024) {
shared[tid] = op(shared[tid], shared[tid + 512]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (BLOCKSIZE >= 512) {
if (tid < 256 && block_dim >= 512) {
shared[tid] = op(shared[tid], shared[tid + 256]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (BLOCKSIZE >= 256) {
if (tid < 128 && block_dim >= 256) {
shared[tid] = op(shared[tid], shared[tid + 128]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (BLOCKSIZE >= 128) {
if (tid < 64 && block_dim >= 128) {
shared[tid] = op(shared[tid], shared[tid + 64]);
threadgroup_barrier(mem_flags::mem_none);
}
}
reduce_block(1024);
reduce_block(512);
reduce_block(256);
reduce_block(128);
if (tid < 32) {
threadgroup_reduce<T, ReductionOp, BLOCKSIZE>(shared, tid, block_dim);
threadgroup_barrier(mem_flags::mem_none);
}
if (tid == 0) {
dst[dst_id] = shared[tid];
}
}
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#undef reduce_block
static constant constexpr int BLOCKSIZE = 2048;
@@ -283,123 +350,146 @@ kernel void NAME##_strided( \
block_dim); \
} \
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define arg_reduce_block(SIZE) \
if (BLOCKSIZE >= SIZE) { \
if (tid < SIZE / 2 \
&& block_dim >= SIZE \
&& arg_op(shared[tid + SIZE / 2], shared[tid]) \
) { \
shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
shared[tid] = shared[tid + SIZE / 2]; \
threadgroup_barrier(mem_flags::mem_none); \
} \
} \
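// Shared body of the contiguous and strided arg-reduce kernels; STRIDED selects which load_from_global overload is used.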
template<
typename T,
typename ArgReductionOp,
uint BLOCKSIZE,
bool STRIDED
>
METAL_FUNC void arg_block_reduce(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
device uint *dst,
threadgroup T shared[BLOCKSIZE],
threadgroup uint shared_indices[BLOCKSIZE],
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ArgReductionOp arg_op;
shared[tid] = ArgReductionOp::init;
shared_indices[tid] = numeric_limits<uint>::max();
if (STRIDED) {
load_from_global<T, ArgReductionOp, BLOCKSIZE>(
num_dims,
dims,
strides,
el_to_sum_per_block,
src,
shared,
shared_indices,
tid,
dst_id,
block_dim
);
} else {
load_from_global<T, ArgReductionOp, BLOCKSIZE>(
num_dims,
dims,
el_to_sum_per_block,
src,
shared,
shared_indices,
tid,
dst_id,
block_dim
);
}
arg_reduce_block(1024);
arg_reduce_block(512);
arg_reduce_block(256);
arg_reduce_block(128);
if (tid < 32) {
threadgroup_reduce<T, ArgReductionOp, BLOCKSIZE>(shared, shared_indices, tid, block_dim);
threadgroup_barrier(mem_flags::mem_none);
}
if (tid == 0) {
dst[dst_id] = shared_indices[0];
}
}
#undef arg_reduce_block
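// ARG_REDUCE emits two kernels per (op, dtype) pair: NAME for contiguous inputs and NAME##_strided for strided inputs; both write the index of the winning element to dst.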
#define ARG_REDUCE(OP, NAME, T) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup T shared[BLOCKSIZE]; \
threadgroup uint shared_indices[BLOCKSIZE]; \
arg_block_reduce<T, OP<T, bool>, BLOCKSIZE, false>( \
num_dims, \
dims, \
strides, \
el_to_sum_per_block, \
src, \
dst, \
shared, \
shared_indices, \
id, \
tid, \
dst_id, \
block_dim); \
} \
kernel void NAME##_strided( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup T shared[BLOCKSIZE]; \
threadgroup uint shared_indices[BLOCKSIZE]; \
arg_block_reduce<T, OP<T, bool>, BLOCKSIZE, true>( \
num_dims, \
dims, \
strides, \
el_to_sum_per_block, \
src, \
dst, \
shared, \
shared_indices, \
id, \
tid, \
dst_id, \
block_dim); \
}
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SOFTMAX(NAME, T) \
kernel void NAME( \
@@ -472,26 +562,31 @@ REDUCE(Sum, fast_sum_f32, float)
REDUCE(Sum, fast_sum_u32, uint)
REDUCE(Sum, fast_sum_f16, half)
REDUCE(Sum, fast_sum_u8, uint8_t)
REDUCE(Mul, fast_mul_f32, float)
REDUCE(Mul, fast_mul_u32, uint)
REDUCE(Mul, fast_mul_f16, half)
REDUCE(Mul, fast_mul_u8, uint8_t)
REDUCE(Max, fast_max_f32, float)
REDUCE(Max, fast_max_u32, uint)
REDUCE(Max, fast_max_f16, half)
REDUCE(Max, fast_max_u8, uint8_t)
REDUCE(Min, fast_min_f32, float)
REDUCE(Min, fast_min_u32, uint)
REDUCE(Min, fast_min_f16, half)
REDUCE(Min, fast_min_u8, uint8_t)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
ARG_REDUCE(ArgMin, fast_argmin_f32, float)
ARG_REDUCE(ArgMin, fast_argmin_f16, half)
ARG_REDUCE(ArgMin, fast_argmin_u32, uint)
ARG_REDUCE(ArgMin, fast_argmin_u8, uint8_t)
ARG_REDUCE(ArgMax, fast_argmax_f32, float)
ARG_REDUCE(ArgMax, fast_argmax_f16, half)
ARG_REDUCE(ArgMax, fast_argmax_u32, uint)
ARG_REDUCE(ArgMax, fast_argmax_u8, uint8_t)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
@@ -502,8 +597,9 @@ REDUCE(Mul, fast_mul_i64, int64_t)
REDUCE(Min, fast_min_i64, int64_t)
REDUCE(Max, fast_max_i64, int64_t)
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
ARG_REDUCE(ArgMin, fast_argmin_i64, int64_t)
ARG_REDUCE(ArgMax, fast_argmax_i64, int64_t)
#endif
#if __METAL_VERSION__ >= 310
@@ -512,7 +608,8 @@ REDUCE(Mul, fast_mul_bf16, bfloat)
REDUCE(Max, fast_max_bf16, bfloat)
REDUCE(Min, fast_min_bf16, bfloat)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
ARG_REDUCE(ArgMin, fast_argmin_bf16, bfloat)
ARG_REDUCE(ArgMax, fast_argmax_bf16, bfloat)
SOFTMAX(softmax_bf16, bfloat)
#endif