Mirror of https://github.com/huggingface/candle.git
Improve arg reduce and add contiguous impl
Benchmark (Rust, criterion):

@@ -1,19 +1,25 @@
-use candle_core::{DType, Tensor};
+use crate::benchmarks::{bench_name, device, BenchDevice};
+use candle_core::{DType, Device, Tensor};
 use criterion::{black_box, criterion_group, Criterion, Throughput};
 use std::time::Instant;
-use crate::benchmarks::{bench_name, device, BenchDevice};
 
-fn run(a: &Tensor) {
+fn run_sum(a: &Tensor) {
     a.sum(2).unwrap();
 }
 
+fn run_arg_min(a: &Tensor) {
+    a.argmin(2).unwrap();
+}
+
 fn criterion_benchmark(c: &mut Criterion) {
+    let device = device().unwrap();
+    run_reduce(c, &device);
+    run_arg_reduce(c, &device);
+}
+
+fn run_reduce(c: &mut Criterion, device: &Device) {
     let b = 1;
     let m = 2048;
     let k = 2048;
 
-    let device = device().unwrap();
-
     let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
 
     let flops = b * m * k * DType::F32.size_in_bytes();
@@ -24,7 +30,31 @@ fn criterion_benchmark(c: &mut Criterion) {
         b.iter_custom(|iters| {
             let start = Instant::now();
             for _i in 0..iters {
-                run(black_box(&a));
+                run_sum(black_box(&a));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn run_arg_reduce(c: &mut Criterion, device: &Device) {
+    let b = 1;
+    let m = 2048;
+    let k = 2048;
+
+    let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
+
+    let flops = b * m * k * DType::F32.size_in_bytes();
+
+    let mut group = c.benchmark_group(bench_name("arg_reduce"));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run_arg_min(black_box(&a));
             }
             device.sync().unwrap();
             start.elapsed()
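For orientation, here is a minimal sketch of what the new run_arg_min path computes at the tensor level. It assumes the CPU device (the benchmark itself picks its device through the device() helper) and candle's usual non-keepdim behaviour where argmin drops the reduced dimension:

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Same shape as the benchmark tensor: (b, m, k) = (1, 2048, 2048).
        let a = Tensor::rand(-1000.0f32, 1000.0f32, (1, 2048, 2048), &Device::Cpu)?;
        // argmin over dim 2 reduces that dimension away and yields integer indices,
        // so the result has shape (1, 2048).
        let idx = a.argmin(2)?;
        assert_eq!(idx.dims(), &[1, 2048]);
        Ok(())
    }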
Metal backend dispatch (Rust):

@@ -511,43 +511,41 @@ impl BackendStorage for MetalStorage {
             (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
             (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
             (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
-            //(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
-            //(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
+            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
+            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
             (ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
             (ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
             (ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
-            //(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
-            //(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
+            (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
+            (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
             (ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
             (ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
             (ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
-            //(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
-            //(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
+            (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
+            (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
             (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
             (ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
             (ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
-            //(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
-            //(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
+            (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
+            (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
             (ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
             (ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
             (ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
-            //(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
-            //(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
+            (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
+            (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
             (ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
             (ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
             (ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
-            //(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
-            //(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
-            //(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
-            _ => ("fall back to strided impl", false, false)
+            (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
+            (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
+            (k, dtype) => {
+                crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
+            }
         };
 
-        if name != "fall back to strided impl" {
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
 
         let buffer = device.new_buffer(1, self.dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_reduce_contiguous(
@@ -564,7 +562,6 @@ impl BackendStorage for MetalStorage {
             .map_err(MetalError::from)?;
             return Ok(Self::new(buffer, device, self.dtype));
         }
-        }
 
         for &dim_idx in sum_dims.iter() {
             dims.push(src_dims[dim_idx]);
@@ -602,7 +599,7 @@ impl BackendStorage for MetalStorage {
             (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
             (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
             (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
-            (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
+            (k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
         };
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
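Taken together, the two dispatch tables follow one naming convention: the kernels enabled here handle contiguous inputs as fast_<op>_<dtype>, while the table above keeps the fast_<op>_<dtype>_strided variants, and the bail messages distinguish the contiguous from the strided path. A tiny sketch of that convention (the helper name is illustrative, not backend code):

    // Sketch of the kernel-name convention used by both dispatch tables:
    // contiguous reductions use fast_<op>_<dtype>, strided ones append "_strided".
    fn kernel_name(op: &str, dtype: &str, contiguous: bool) -> String {
        if contiguous {
            format!("fast_{op}_{dtype}")
        } else {
            format!("fast_{op}_{dtype}_strided")
        }
    }

    fn main() {
        assert_eq!(kernel_name("argmin", "f32", true), "fast_argmin_f32");
        assert_eq!(kernel_name("argmin", "f32", false), "fast_argmin_f32_strided");
    }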
Metal reduce kernels (Metal Shading Language):

@@ -19,24 +19,24 @@ METAL_FUNC uint get_strided_index(
 }
 
 #define impl_reduction_op(name, op, init_val) \
-template<typename T> \
+template<typename T, typename R = T> \
 struct name { \
 \
     static constexpr constant T init = init_val; \
 \
-    METAL_FUNC T operator()(thread const T &a, thread const T &b) const { \
+    METAL_FUNC R operator()(thread const T &a, thread const T &b) const { \
         return op; \
     } \
 \
-    METAL_FUNC T operator()(threadgroup const T &a, threadgroup const T &b) const { \
+    METAL_FUNC R operator()(threadgroup const T &a, threadgroup const T &b) const { \
         return op; \
     } \
 \
-    METAL_FUNC T operator()(device const T &a, device const T &b) const { \
+    METAL_FUNC R operator()(device const T &a, device const T &b) const { \
         return op; \
     } \
 \
-    METAL_FUNC T operator()(T a, T b) { \
+    METAL_FUNC R operator()(T a, T b) { \
         return op; \
     } \
 } \
@@ -45,10 +45,13 @@ impl_reduction_op(Sum, a + b, 0);
 impl_reduction_op(Mul, a * b, 1);
 impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
 impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::min());
+impl_reduction_op(ArgMin, a < b, numeric_limits<T>::max());
+impl_reduction_op(ArgMax, a > b, numeric_limits<T>::min());
 #undef impl_reduction_op
 
 static constant constexpr int THREADGROUP_SIZE = 2048;
 
+// Load strided elements from global memory into shared memory.
 template<typename T, typename ReductionOp, uint BLOCKSIZE>
 METAL_FUNC void load_from_global(
     constant size_t &num_dims,
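The new ArgMin/ArgMax instantiations reuse impl_reduction_op with a bare comparison (a < b, a > b) as the body, so instead of producing a combined value they act as predicates that decide whether an incoming element should replace the current best. A rough sequential Rust sketch of that update rule (illustrative only, not the kernel code):

    /// Sketch of the arg-min update rule: keep a value and its index whenever the
    /// comparison predicate prefers the new element over the current best.
    fn argmin_sketch(src: &[f32]) -> Option<usize> {
        let mut best: Option<(f32, usize)> = None;
        for (i, &v) in src.iter().enumerate() {
            let take = match best {
                None => true,            // mirrors the kernels' `notset` flag
                Some((bv, _)) => v < bv, // mirrors the ArgMin predicate `a < b`
            };
            if take {
                best = Some((v, i));
            }
        }
        best.map(|(_, i)| i)
    }

    fn main() {
        assert_eq!(argmin_sketch(&[3.0, -1.0, 2.0]), Some(1));
    }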
@@ -74,6 +77,40 @@ METAL_FUNC void load_from_global(
     threadgroup_barrier(mem_flags::mem_none);
 }
 
+// Load strided elements from global memory into shared memory with indices.
+template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
+METAL_FUNC void load_from_global(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    threadgroup T shared[BLOCKSIZE],
+    threadgroup uint shared_indices[BLOCKSIZE],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    ArgReductionOp op;
+    bool notset = true;
+
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = start_idx + el_to_sum_per_block;
+    size_t idx = start_idx + tid;
+    while (idx < stop_idx) {
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+        if (notset || op(src[strided_i], shared[tid])) {
+            shared[tid] = src[strided_i];
+            // Assume that the reduction takes place over the last dimension which is contiguous.
+            shared_indices[tid] = idx % dims[num_dims - 1];
+            notset = false;
+        }
+        idx += block_dim;
+    }
+    threadgroup_barrier(mem_flags::mem_none);
+}
+
+// Load contiguous elements from global memory into shared memory.
 template<typename T, typename ReductionOp, uint BLOCKSIZE>
 METAL_FUNC void load_from_global(
     constant size_t &num_dims,
@@ -97,6 +134,45 @@ METAL_FUNC void load_from_global(
     threadgroup_barrier(mem_flags::mem_none);
 }
 
+// Load contiguous elements from global memory into shared memory with indices.
+template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
+METAL_FUNC void load_from_global(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    threadgroup T shared[BLOCKSIZE],
+    threadgroup uint shared_indices[BLOCKSIZE],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    ArgReductionOp op;
+    bool notset = true;
+
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = start_idx + el_to_sum_per_block;
+    size_t idx = start_idx + tid;
+    while (idx < stop_idx) {
+        if (notset || op(src[idx], shared[tid])) {
+            shared[tid] = src[idx];
+            // Assume that the reduction takes place over the last dimension which is contiguous.
+            shared_indices[tid] = idx % dims[num_dims - 1];
+            notset = false;
+        }
+        idx += block_dim;
+    }
+    threadgroup_barrier(mem_flags::mem_none);
+}
+
+#define reduce_threadgroup(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (block_dim >= SIZE) { \
+        shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+}
+
 template<typename T, typename ReductionOp, uint BLOCKSIZE>
 METAL_FUNC void threadgroup_reduce(
     threadgroup T shared[BLOCKSIZE],
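Both index-tracking loaders record idx % dims[num_dims - 1] as the candidate index, i.e. the element's offset inside the last (contiguous) dimension being reduced. A small sketch of that mapping, assuming a contiguous (b, m, k) layout:

    /// For a contiguous (b, m, k) tensor, element (i, j, l) sits at flat index
    /// (i * m + j) * k + l, so `flat % k` recovers l, the position inside the
    /// reduced last dimension.
    fn last_dim_offset(flat: usize, k: usize) -> usize {
        flat % k
    }

    fn main() {
        let (m, k) = (4usize, 5usize);
        let flat = (2 * m + 2) * k + 3; // element (2, 2, 3) of a (b, 4, 5) tensor
        assert_eq!(last_dim_offset(flat, k), 3);
    }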
@@ -104,37 +180,50 @@ METAL_FUNC void threadgroup_reduce(
     uint block_dim [[ threads_per_threadgroup ]]
 ) {
     ReductionOp op;
-    if (BLOCKSIZE >= 64) {
-        if (block_dim >= 64) {
-            shared[tid] = op(shared[tid], shared[tid + 32]);
-        }
-    }
-    if (BLOCKSIZE >= 32) {
-        if (block_dim >= 32) {
-            shared[tid] = op(shared[tid], shared[tid + 16]);
-        }
-    }
-    if (BLOCKSIZE >= 16) {
-        if (block_dim >= 16) {
-            shared[tid] = op(shared[tid], shared[tid + 8]);
-        }
-    }
-    if (BLOCKSIZE >= 8) {
-        if (block_dim >= 8) {
-            shared[tid] = op(shared[tid], shared[tid + 4]);
-        }
-    }
-    if (BLOCKSIZE >= 4) {
-        if (block_dim >= 4) {
-            shared[tid] = op(shared[tid], shared[tid + 2]);
-        }
-    }
-    if (BLOCKSIZE >= 2) {
-        if (block_dim >= 2) {
-            shared[tid] = op(shared[tid], shared[tid + 1]);
-        }
-    }
+    reduce_threadgroup(64);
+    reduce_threadgroup(32);
+    reduce_threadgroup(16);
+    reduce_threadgroup(8);
+    reduce_threadgroup(4);
+    reduce_threadgroup(2);
 }
+#undef reduce_threadgroup
+
+#define arg_reduce_threadgroup(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (block_dim >= SIZE && \
+        op(shared[tid], shared[tid + SIZE / 2]) \
+    ) { \
+        shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
+        shared[tid] = shared[tid + SIZE / 2]; \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+}
+
+template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
+METAL_FUNC void threadgroup_reduce(
+    threadgroup T shared[BLOCKSIZE],
+    threadgroup uint shared_indices[BLOCKSIZE],
+    uint tid [[thread_index_in_threadgroup]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    ArgReductionOp op;
+    arg_reduce_threadgroup(64);
+    arg_reduce_threadgroup(32);
+    arg_reduce_threadgroup(16);
+    arg_reduce_threadgroup(8);
+    arg_reduce_threadgroup(4);
+    arg_reduce_threadgroup(2);
+}
+#undef arg_reduce_threadgroup
+
+#define reduce_block(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (tid < SIZE / 2 && block_dim >= SIZE) { \
+        shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+} \
 
 // Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
 template<
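reduce_threadgroup and arg_reduce_threadgroup unroll the usual halving tree over shared memory: each step cuts the active width in two and combines slot tid with slot tid + width / 2. A sequential Rust sketch of the same idea (assuming a non-empty, power-of-two input length):

    /// Sequential sketch of the shared-memory tree reduction the macros unroll:
    /// each pass halves the active width and combines slot i with slot i + half.
    fn tree_reduce(shared: &mut [f32], op: impl Fn(f32, f32) -> f32) -> f32 {
        let mut width = shared.len();
        while width > 1 {
            let half = width / 2;
            for i in 0..half {
                shared[i] = op(shared[i], shared[i + half]);
            }
            width = half;
        }
        shared[0]
    }

    fn main() {
        let mut v = [1.0f32, 5.0, 3.0, 8.0, 2.0, 7.0, 6.0, 4.0];
        assert_eq!(tree_reduce(&mut v, f32::max), 8.0);
    }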
@@ -186,42 +275,20 @@ METAL_FUNC void block_reduce(
         );
     }
 
-    if (BLOCKSIZE >= 1024) {
-        if (tid < 512 && block_dim >= 1024) {
-            shared[tid] = op(shared[tid], shared[tid + 512]);
-            threadgroup_barrier(mem_flags::mem_none);
-        }
-    }
-    if (BLOCKSIZE >= 512) {
-        if (tid < 256 && block_dim >= 512) {
-            shared[tid] = op(shared[tid], shared[tid + 256]);
-            threadgroup_barrier(mem_flags::mem_none);
-        }
-    }
-    if (BLOCKSIZE >= 256) {
-        if (tid < 128 && block_dim >= 256) {
-            shared[tid] = op(shared[tid], shared[tid + 128]);
-            threadgroup_barrier(mem_flags::mem_none);
-        }
-    }
-    if (BLOCKSIZE >= 128) {
-        if (tid < 64 && block_dim >= 128) {
-            shared[tid] = op(shared[tid], shared[tid + 64]);
-            threadgroup_barrier(mem_flags::mem_none);
-        }
-    }
+    reduce_block(1024);
+    reduce_block(512);
+    reduce_block(256);
+    reduce_block(128);
 
     if (tid < 32) {
         threadgroup_reduce<T, ReductionOp, BLOCKSIZE>(shared, tid, block_dim);
         threadgroup_barrier(mem_flags::mem_none);
     }
 
     if (tid == 0) {
         dst[dst_id] = shared[tid];
     }
 }
+#undef reduce_block
 
-#define MAX(x, y) ((x) > (y) ? (x) : (y))
-#define MIN(x, y) ((x) < (y) ? (x) : (y))
-
 static constant constexpr int BLOCKSIZE = 2048;
@@ -283,7 +350,86 @@ kernel void NAME##_strided( \
         block_dim); \
 } \
 
-#define ARGMIN(NAME, T, MAXVALUE) \
+#define arg_reduce_block(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (tid < SIZE / 2 \
+        && block_dim >= SIZE \
+        && arg_op(shared[tid], shared[tid + SIZE / 2]) \
+    ) { \
+        shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
+        shared[tid] = shared[tid + SIZE / 2]; \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+} \
+
+template<
+    typename T,
+    typename ArgReductionOp,
+    uint BLOCKSIZE,
+    bool STRIDED
+>
+METAL_FUNC void arg_block_reduce(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device uint *dst,
+    threadgroup T shared[BLOCKSIZE],
+    threadgroup uint shared_indices[BLOCKSIZE],
+    uint id [[ thread_position_in_grid ]],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    ArgReductionOp arg_op;
+
+    shared[tid] = ArgReductionOp::init;
+    shared_indices[tid] = numeric_limits<uint>::max();
+
+    if (STRIDED) {
+        load_from_global<T, ArgReductionOp, BLOCKSIZE>(
+            num_dims,
+            dims,
+            strides,
+            el_to_sum_per_block,
+            src,
+            shared,
+            shared_indices,
+            tid,
+            dst_id,
+            block_dim
+        );
+    } else {
+        load_from_global<T, ArgReductionOp, BLOCKSIZE>(
+            num_dims,
+            dims,
+            el_to_sum_per_block,
+            src,
+            shared,
+            shared_indices,
+            tid,
+            dst_id,
+            block_dim
+        );
+    }
+    arg_reduce_block(1024);
+    arg_reduce_block(512);
+    arg_reduce_block(256);
+    arg_reduce_block(128);
+
+    if (tid < 32) {
+        threadgroup_reduce<T, ArgReductionOp, BLOCKSIZE>(shared, shared_indices, tid, block_dim);
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    if (tid == 0) {
+        dst[dst_id] = shared_indices[0];
+    }
+}
+#undef arg_reduce_block
+
+#define ARG_REDUCE(OP, NAME, T) \
 kernel void NAME( \
 constant size_t &num_dims, \
 constant size_t *dims, \
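arg_block_reduce runs the same block-level tree reduction but carries a second shared array of indices, and the comparison result (rather than a combined value) decides which (value, index) pair survives each step. A sequential Rust sketch of that idea (again assuming a non-empty, power-of-two input, and following the loaders' "candidate wins the comparison" convention):

    /// Sequential sketch of the arg tree reduction: values and their source indices
    /// move together, and the value comparison decides which pair survives.
    fn arg_tree_reduce(vals: &mut [f32], idxs: &mut [u32], prefer: impl Fn(f32, f32) -> bool) -> u32 {
        let mut width = vals.len();
        while width > 1 {
            let half = width / 2;
            for i in 0..half {
                // `prefer(a, b)` plays the role of the ArgMin/ArgMax predicate:
                // take the right-hand candidate when it wins the comparison.
                if prefer(vals[i + half], vals[i]) {
                    vals[i] = vals[i + half];
                    idxs[i] = idxs[i + half];
                }
            }
            width = half;
        }
        idxs[0]
    }

    fn main() {
        let mut vals = [4.0f32, -2.0, 7.0, 0.5];
        let mut idxs = [0u32, 1, 2, 3];
        // arg-min: prefer the smaller value.
        assert_eq!(arg_tree_reduce(&mut vals, &mut idxs, |a, b| a < b), 1);
    }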
@@ -296,55 +442,23 @@ kernel void NAME( \
 uint dst_id [[ threadgroup_position_in_grid ]], \
 uint block_dim [[ threads_per_threadgroup ]] \
 ) { \
-\
-    threadgroup T shared_memory[THREADGROUP_SIZE]; \
-    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
-\
-    shared_memory[tid] = MAXVALUE; \
-    shared_indices[tid] = 0xFFFFFFFF; \
-    bool notset = true; \
-    /* \
-    // Elements summed in this block range from dst_id * el_to_sum_per_block \
-    // to (dst_id + 1) * el_to_sum_per_block. \
-    */ \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = start_idx + el_to_sum_per_block; \
-    size_t idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        /* \
-        // TODO: Fast version for the contiguous case. \
-        */ \
-        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
-        if (notset || src[strided_i] < shared_memory[tid]) { \
-            shared_memory[tid] = src[strided_i]; \
-            /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
-            shared_indices[tid] = idx % dims[num_dims - 1]; \
-            notset = false; \
-        } \
-        idx += block_dim; \
-    } \
-\
-    threadgroup_barrier(mem_flags::mem_none); \
-\
-    /* \
-    // reduction in shared memory \
-    */ \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
-            shared_indices[tid] = shared_indices[tid + s]; \
-            shared_memory[tid] = shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_none); \
-    } \
-\
-    if (tid == 0){ \
-        dst[dst_id] = shared_indices[0]; \
-    } \
-} \
-
-
-#define ARGMAX(NAME, T, MINVALUE) \
-kernel void NAME( \
+    threadgroup T shared[BLOCKSIZE]; \
+    threadgroup uint shared_indices[BLOCKSIZE]; \
+    arg_block_reduce<T, OP<T, bool>, BLOCKSIZE, false>( \
+        num_dims, \
+        dims, \
+        strides, \
+        el_to_sum_per_block, \
+        src, \
+        dst, \
+        shared, \
+        shared_indices, \
+        id, \
+        tid, \
+        dst_id, \
+        block_dim); \
+} \
+kernel void NAME##_strided( \
 constant size_t &num_dims, \
 constant size_t *dims, \
 constant size_t *strides, \
@@ -356,50 +470,26 @@ kernel void NAME( \
 uint dst_id [[ threadgroup_position_in_grid ]], \
 uint block_dim [[ threads_per_threadgroup ]] \
 ) { \
-\
-    threadgroup T shared_memory[THREADGROUP_SIZE]; \
-    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
-\
-    shared_memory[tid] = MINVALUE; \
-    shared_indices[tid] = 0xFFFFFFFF; \
-    /* \
-    // Elements summed in this block range from dst_id * el_to_sum_per_block \
-    // to (dst_id + 1) * el_to_sum_per_block. \
-    */ \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = start_idx + el_to_sum_per_block; \
-    size_t idx = start_idx + tid; \
-    bool notset = true; \
-    while (idx < stop_idx) { \
-        /* \
-        // TODO: Fast version for the contiguous case. \
-        */ \
-        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
-        if (notset || shared_memory[tid] < src[strided_i]) { \
-            shared_memory[tid] = src[strided_i]; \
-            shared_indices[tid] = idx % dims[num_dims - 1]; \
-            notset = false; \
-        } \
-        idx += block_dim; \
-    } \
-\
-    threadgroup_barrier(mem_flags::mem_none); \
-\
-    /* \
-    // reduction in shared memory \
-    */ \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
-            shared_indices[tid] = shared_indices[tid + s]; \
-            shared_memory[tid] = shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_none); \
-    } \
-\
-    if (tid == 0){ \
-        dst[dst_id] = shared_indices[0]; \
-    } \
-} \
+    threadgroup T shared[BLOCKSIZE]; \
+    threadgroup uint shared_indices[BLOCKSIZE]; \
+    arg_block_reduce<T, OP<T, bool>, BLOCKSIZE, true>( \
+        num_dims, \
+        dims, \
+        strides, \
+        el_to_sum_per_block, \
+        src, \
+        dst, \
+        shared, \
+        shared_indices, \
+        id, \
+        tid, \
+        dst_id, \
+        block_dim); \
+}
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
 
 #define SOFTMAX(NAME, T) \
 kernel void NAME( \
@@ -472,26 +562,31 @@ REDUCE(Sum, fast_sum_f32, float)
 REDUCE(Sum, fast_sum_u32, uint)
 REDUCE(Sum, fast_sum_f16, half)
 REDUCE(Sum, fast_sum_u8, uint8_t)
 
 REDUCE(Mul, fast_mul_f32, float)
 REDUCE(Mul, fast_mul_u32, uint)
 REDUCE(Mul, fast_mul_f16, half)
+REDUCE(Mul, fast_mul_u8, uint8_t)
 
 REDUCE(Max, fast_max_f32, float)
 REDUCE(Max, fast_max_u32, uint)
 REDUCE(Max, fast_max_f16, half)
 REDUCE(Max, fast_max_u8, uint8_t)
 
 REDUCE(Min, fast_min_f32, float)
 REDUCE(Min, fast_min_u32, uint)
 REDUCE(Min, fast_min_f16, half)
 REDUCE(Min, fast_min_u8, uint8_t)
 
-ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
-ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
-ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
-ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
-ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
-ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
-ARGMAX(fast_argmax_u32_strided, uint, 0)
-ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
+ARG_REDUCE(ArgMin, fast_argmin_f32, float)
+ARG_REDUCE(ArgMin, fast_argmin_f16, half)
+ARG_REDUCE(ArgMin, fast_argmin_u32, uint)
+ARG_REDUCE(ArgMin, fast_argmin_u8, uint8_t)
+
+ARG_REDUCE(ArgMax, fast_argmax_f32, float)
+ARG_REDUCE(ArgMax, fast_argmax_f16, half)
+ARG_REDUCE(ArgMax, fast_argmax_u32, uint)
+ARG_REDUCE(ArgMax, fast_argmax_u8, uint8_t)
 
 SOFTMAX(softmax_f32, float)
 SOFTMAX(softmax_f16, half)
@@ -502,8 +597,9 @@ REDUCE(Mul, fast_mul_i64, int64_t)
 REDUCE(Min, fast_min_i64, int64_t)
 REDUCE(Max, fast_max_i64, int64_t)
 
-ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
-ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
+ARG_REDUCE(ArgMin, fast_argmin_i64, int64_t)
+ARG_REDUCE(ArgMax, fast_argmax_i64, int64_t)
 
 #endif
 
 #if __METAL_VERSION__ >= 310
@@ -512,7 +608,8 @@ REDUCE(Mul, fast_mul_bf16, bfloat)
 REDUCE(Max, fast_max_bf16, bfloat)
 REDUCE(Min, fast_min_bf16, bfloat)
 
-ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
-ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
+ARG_REDUCE(ArgMin, fast_argmin_bf16, bfloat)
+ARG_REDUCE(ArgMax, fast_argmax_bf16, bfloat)
 
 SOFTMAX(softmax_bf16, bfloat)
 #endif