Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Improve arg reduce and add contiguous impl
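Before the diff itself, a minimal sketch of the user-facing call path this change speeds up, assuming a Metal-enabled build of candle; `Device::new_metal(0)` is used here purely for illustration, while the shape, value range, and `argmin(2)` call mirror the benchmark in the diff below.

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Illustrative Metal device setup; the benchmark below goes through its
        // own device() helper instead of constructing the device directly.
        let device = Device::new_metal(0)?;
        // Same shape and value range as the benchmark: (1, 2048, 2048) f32.
        let a = Tensor::rand(-1000.0f32, 1000.0f32, (1, 2048, 2048), &device)?;
        // With this commit, a reduction over the contiguous last dimension can
        // dispatch to the new fast_argmin_f32 kernel instead of falling back to
        // the strided variant.
        let indices = a.argmin(2)?;
        println!("argmin indices shape: {:?}", indices.shape());
        Ok(())
    }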
@@ -1,4 +1,4 @@
mod benchmarks;

use criterion::criterion_main;
criterion_main!(benchmarks::reduce::benches);
criterion_main!(benchmarks::reduce::benches);
@@ -1,19 +1,25 @@
use candle_core::{DType, Tensor};
use crate::benchmarks::{bench_name, device, BenchDevice};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
use crate::benchmarks::{bench_name, device, BenchDevice};

fn run(a: &Tensor) {
fn run_sum(a: &Tensor) {
a.sum(2).unwrap();
}
fn run_arg_min(a: &Tensor) {
a.argmin(2).unwrap();
}

fn criterion_benchmark(c: &mut Criterion) {
let device = device().unwrap();
run_reduce(c, &device);
run_arg_reduce(c, &device);
}
fn run_reduce(c: &mut Criterion, device: &Device) {
let b = 1;
let m = 2048;
let k = 2048;

let device = device().unwrap();

let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();

let flops = b * m * k * DType::F32.size_in_bytes();
@@ -24,7 +30,31 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&a));
run_sum(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}

fn run_arg_reduce(c: &mut Criterion, device: &Device) {
let b = 1;
let m = 2048;
let k = 2048;

let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();

let flops = b * m * k * DType::F32.size_in_bytes();

let mut group = c.benchmark_group(bench_name("arg_reduce"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_arg_min(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
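The hunk above stops before the closing braces and the criterion registration, so for completeness here is a sketch, assuming the standard criterion pattern, of how `run_reduce`/`run_arg_reduce` end up exposed as the `benches` group that `criterion_main!(benchmarks::reduce::benches)` in bench_main.rs refers to; the registration lines themselves are not part of this diff.

    use criterion::{criterion_group, Criterion};

    fn criterion_benchmark(c: &mut Criterion) {
        // In the real module this is the function shown in the diff above, which
        // resolves the device and then calls run_reduce(c, &device) and
        // run_arg_reduce(c, &device).
        let _ = c;
    }

    // Assumed registration following the usual criterion pattern; the group name
    // `benches` matches the `benchmarks::reduce::benches` path in bench_main.rs.
    criterion_group!(benches, criterion_benchmark);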
@@ -511,59 +511,56 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
//(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
//(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
//(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
//(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
//(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
//(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
//(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
//(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
//(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
//(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
//(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
//(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
//(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
_ => ("fall back to strided impl", false, false)
};

if name != "fall back to strided impl" {
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
(k, dtype) => {
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
}

let buffer = device.new_buffer(1, self.dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.shape().elem_count(),
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, self.dtype));
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}

let buffer = device.new_buffer(1, self.dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.shape().elem_count(),
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, self.dtype));
}

for &dim_idx in sum_dims.iter() {
@@ -602,7 +599,7 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
@@ -19,24 +19,24 @@ METAL_FUNC uint get_strided_index(
}

#define impl_reduction_op(name, op, init_val) \
template<typename T> \
template<typename T, typename R = T> \
struct name { \
\
static constexpr constant T init = init_val; \
\
METAL_FUNC T operator()(thread const T &a, thread const T &b) const { \
METAL_FUNC R operator()(thread const T &a, thread const T &b) const { \
return op; \
} \
\
METAL_FUNC T operator()(threadgroup const T &a, threadgroup const T &b) const { \
METAL_FUNC R operator()(threadgroup const T &a, threadgroup const T &b) const { \
return op; \
} \
\
METAL_FUNC T operator()(device const T &a, device const T &b) const { \
METAL_FUNC R operator()(device const T &a, device const T &b) const { \
return op; \
} \
\
METAL_FUNC T operator()(T a, T b) { \
METAL_FUNC R operator()(T a, T b) { \
return op; \
} \
} \
@@ -45,10 +45,13 @@ impl_reduction_op(Sum, a + b, 0);
impl_reduction_op(Mul, a * b, 1);
impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::min());
impl_reduction_op(ArgMin, a < b, numeric_limits<T>::max());
impl_reduction_op(ArgMax, a > b, numeric_limits<T>::min());
#undef impl_reduction_op

static constant constexpr int THREADGROUP_SIZE = 2048;

// Load strided elements from global memory into shared memory.
template<typename T, typename ReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
@@ -74,6 +77,40 @@ METAL_FUNC void load_from_global(
threadgroup_barrier(mem_flags::mem_none);
}

// Load strided elements from global memory into shared memory with indices.
template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
threadgroup T shared[BLOCKSIZE],
threadgroup uint shared_indices[BLOCKSIZE],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ArgReductionOp op;
bool notset = true;

size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || op(src[strided_i], shared[tid])) {
shared[tid] = src[strided_i];
// Assume that the reduction takes place over the last dimension which is contiguous.
shared_indices[tid] = idx % dims[num_dims - 1];
notset = false;
}
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
}

// Load contiguous elements from global memory into shared memory.
template<typename T, typename ReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
@@ -97,6 +134,45 @@ METAL_FUNC void load_from_global(
threadgroup_barrier(mem_flags::mem_none);
}

// Load contiguous elements from global memory into shared memory with indices.
template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
constant size_t *dims,
constant size_t &el_to_sum_per_block,
device const T *src,
threadgroup T shared[BLOCKSIZE],
threadgroup uint shared_indices[BLOCKSIZE],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ArgReductionOp op;
bool notset = true;

size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
if (notset || op(src[idx], shared[tid])) {
shared[tid] = src[idx];
// Assume that the reduction takes place over the last dimension which is contiguous.
shared_indices[tid] = idx % dims[num_dims - 1];
notset = false;
}
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
}

#define reduce_threadgroup(SIZE) \
if (BLOCKSIZE >= SIZE) { \
if (block_dim >= SIZE) { \
shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
threadgroup_barrier(mem_flags::mem_none); \
} \
}

template<typename T, typename ReductionOp, uint BLOCKSIZE>
METAL_FUNC void threadgroup_reduce(
threadgroup T shared[BLOCKSIZE],
@@ -104,37 +180,50 @@ METAL_FUNC void threadgroup_reduce(
uint block_dim [[ threads_per_threadgroup ]]
) {
ReductionOp op;
if (BLOCKSIZE >= 64) {
if (block_dim >= 64) {
shared[tid] = op(shared[tid], shared[tid + 32]);
}
}
if (BLOCKSIZE >= 32) {
if (block_dim >= 32) {
shared[tid] = op(shared[tid], shared[tid + 16]);
}
}
if (BLOCKSIZE >= 16) {
if (block_dim >= 16) {
shared[tid] = op(shared[tid], shared[tid + 8]);
}
}
if (BLOCKSIZE >= 8) {
if (block_dim >= 8) {
shared[tid] = op(shared[tid], shared[tid + 4]);
}
}
if (BLOCKSIZE >= 4) {
if (block_dim >= 4) {
shared[tid] = op(shared[tid], shared[tid + 2]);
}
}
if (BLOCKSIZE >= 2) {
if (block_dim >= 2) {
shared[tid] = op(shared[tid], shared[tid + 1]);
}
}
reduce_threadgroup(64);
reduce_threadgroup(32);
reduce_threadgroup(16);
reduce_threadgroup(8);
reduce_threadgroup(4);
reduce_threadgroup(2);
}
#undef reduce_threadgroup

#define arg_reduce_threadgroup(SIZE) \
if (BLOCKSIZE >= SIZE) { \
if (block_dim >= SIZE && \
op(shared[tid], shared[tid + SIZE / 2]) \
) { \
shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
shared[tid] = shared[tid + SIZE / 2]; \
threadgroup_barrier(mem_flags::mem_none); \
} \
}

template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
METAL_FUNC void threadgroup_reduce(
threadgroup T shared[BLOCKSIZE],
threadgroup uint shared_indices[BLOCKSIZE],
uint tid [[thread_index_in_threadgroup]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ArgReductionOp op;
arg_reduce_threadgroup(64);
arg_reduce_threadgroup(32);
arg_reduce_threadgroup(16);
arg_reduce_threadgroup(8);
arg_reduce_threadgroup(4);
arg_reduce_threadgroup(2);
}
#undef arg_reduce_threadgroup

#define reduce_block(SIZE) \
if (BLOCKSIZE >= SIZE) { \
if (tid < SIZE / 2 && block_dim >= SIZE) { \
shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
threadgroup_barrier(mem_flags::mem_none); \
} \
} \

// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
template<
@@ -186,42 +275,20 @@ METAL_FUNC void block_reduce(
);
}

if (BLOCKSIZE >= 1024) {
if (tid < 512 && block_dim >= 1024) {
shared[tid] = op(shared[tid], shared[tid + 512]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (BLOCKSIZE >= 512) {
if (tid < 256 && block_dim >= 512) {
shared[tid] = op(shared[tid], shared[tid + 256]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (BLOCKSIZE >= 256) {
if (tid < 128 && block_dim >= 256) {
shared[tid] = op(shared[tid], shared[tid + 128]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (BLOCKSIZE >= 128) {
if (tid < 64 && block_dim >= 128) {
shared[tid] = op(shared[tid], shared[tid + 64]);
threadgroup_barrier(mem_flags::mem_none);
}
}
reduce_block(1024);
reduce_block(512);
reduce_block(256);
reduce_block(128);

if (tid < 32) {
threadgroup_reduce<T, ReductionOp, BLOCKSIZE>(shared, tid, block_dim);
threadgroup_barrier(mem_flags::mem_none);
}

if (tid == 0) {
dst[dst_id] = shared[tid];
}
}

#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#undef reduce_block

static constant constexpr int BLOCKSIZE = 2048;

@@ -283,123 +350,146 @@ kernel void NAME##_strided( \
block_dim); \
} \

#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define arg_reduce_block(SIZE) \
if (BLOCKSIZE >= SIZE) { \
if (tid < SIZE / 2 \
&& block_dim >= SIZE \
&& arg_op(shared[tid], shared[tid + SIZE / 2]) \
) { \
shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
shared[tid] = shared[tid + SIZE / 2]; \
threadgroup_barrier(mem_flags::mem_none); \
} \
} \

template<
typename T,
typename ArgReductionOp,
uint BLOCKSIZE,
bool STRIDED
>
METAL_FUNC void arg_block_reduce(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
device uint *dst,
threadgroup T shared[BLOCKSIZE],
threadgroup uint shared_indices[BLOCKSIZE],
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ArgReductionOp arg_op;

shared[tid] = ArgReductionOp::init;
shared_indices[tid] = numeric_limits<uint>::max();

if (STRIDED) {
load_from_global<T, ArgReductionOp, BLOCKSIZE>(
num_dims,
dims,
strides,
el_to_sum_per_block,
src,
shared,
shared_indices,
tid,
dst_id,
block_dim
);
} else {
load_from_global<T, ArgReductionOp, BLOCKSIZE>(
num_dims,
dims,
el_to_sum_per_block,
src,
shared,
shared_indices,
tid,
dst_id,
block_dim
);
}
arg_reduce_block(1024);
arg_reduce_block(512);
arg_reduce_block(256);
arg_reduce_block(128);

if (tid < 32) {
threadgroup_reduce<T, ArgReductionOp, BLOCKSIZE>(shared, shared_indices, tid, block_dim);
threadgroup_barrier(mem_flags::mem_none);
}

if (tid == 0) {
dst[dst_id] = shared_indices[0];
}
}
#undef arg_reduce_block

#define ARG_REDUCE(OP, NAME, T) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup T shared[BLOCKSIZE]; \
threadgroup uint shared_indices[BLOCKSIZE]; \
arg_block_reduce<T, OP<T, bool>, BLOCKSIZE, false>( \
num_dims, \
dims, \
strides, \
el_to_sum_per_block, \
src, \
dst, \
shared, \
shared_indices, \
id, \
tid, \
dst_id, \
block_dim); \
} \
kernel void NAME##_strided( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup T shared[BLOCKSIZE]; \
threadgroup uint shared_indices[BLOCKSIZE]; \
arg_block_reduce<T, OP<T, bool>, BLOCKSIZE, true>( \
num_dims, \
dims, \
strides, \
el_to_sum_per_block, \
src, \
dst, \
shared, \
shared_indices, \
id, \
tid, \
dst_id, \
block_dim); \
}


#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))

#define SOFTMAX(NAME, T) \
kernel void NAME( \
@@ -472,26 +562,31 @@ REDUCE(Sum, fast_sum_f32, float)
REDUCE(Sum, fast_sum_u32, uint)
REDUCE(Sum, fast_sum_f16, half)
REDUCE(Sum, fast_sum_u8, uint8_t)

REDUCE(Mul, fast_mul_f32, float)
REDUCE(Mul, fast_mul_u32, uint)
REDUCE(Mul, fast_mul_f16, half)
REDUCE(Mul, fast_mul_u8, uint8_t)

REDUCE(Max, fast_max_f32, float)
REDUCE(Max, fast_max_u32, uint)
REDUCE(Max, fast_max_f16, half)
REDUCE(Max, fast_max_u8, uint8_t)

REDUCE(Min, fast_min_f32, float)
REDUCE(Min, fast_min_u32, uint)
REDUCE(Min, fast_min_f16, half)
REDUCE(Min, fast_min_u8, uint8_t)

ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
ARG_REDUCE(ArgMin, fast_argmin_f32, float)
ARG_REDUCE(ArgMin, fast_argmin_f16, half)
ARG_REDUCE(ArgMin, fast_argmin_u32, uint)
ARG_REDUCE(ArgMin, fast_argmin_u8, uint8_t)

ARG_REDUCE(ArgMax, fast_argmax_f32, float)
ARG_REDUCE(ArgMax, fast_argmax_f16, half)
ARG_REDUCE(ArgMax, fast_argmax_u32, uint)
ARG_REDUCE(ArgMax, fast_argmax_u8, uint8_t)

SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
@@ -502,8 +597,9 @@ REDUCE(Mul, fast_mul_i64, int64_t)
REDUCE(Min, fast_min_i64, int64_t)
REDUCE(Max, fast_max_i64, int64_t)

ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
ARG_REDUCE(ArgMin, fast_argmin_i64, int64_t)
ARG_REDUCE(ArgMax, fast_argmax_i64, int64_t)

#endif

#if __METAL_VERSION__ >= 310
@@ -512,7 +608,8 @@ REDUCE(Mul, fast_mul_bf16, bfloat)
REDUCE(Max, fast_max_bf16, bfloat)
REDUCE(Min, fast_min_bf16, bfloat)

ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
ARG_REDUCE(ArgMin, fast_argmin_bf16, bfloat)
ARG_REDUCE(ArgMax, fast_argmax_bf16, bfloat)

SOFTMAX(softmax_bf16, bfloat)
#endif