Cuda kernels for fast min/max reductions (#203)

* Add the min/max cuda kernels.

* Better integration of the cuda kernels.
Laurent Mazare
2023-07-19 19:12:27 +02:00
committed by GitHub
parent 001f9a59ce
commit 536c5e702e
3 changed files with 130 additions and 22 deletions


@@ -2,6 +2,7 @@
 // https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
 #include "cuda_utils.cuh"
 #include<stdint.h>
+#include<cmath>
 
 const int BLOCK_SIZE = 1024;
@@ -27,7 +28,7 @@ __device__ void fast_sum(
   size_t tid = threadIdx.x;
   size_t dst_id = blockIdx.x;
-  shr[tid] = 0.0;
+  shr[tid] = 0;
   // Elements summed in this block range from dst_id * el_to_sum_per_block
   // to (dst_id + 1) * el_to_sum_per_block.
   size_t start_idx = dst_id * el_to_sum_per_block;
@@ -49,11 +50,113 @@ __device__ void fast_sum(
     if (tid < s) shr[tid] += shr[tid + s];
   }
-  if (tid == 0) atomicAdd(dst + dst_id, shr[0]);
+  if (tid == 0) dst[dst_id] = shr[0];
 }
 
-#define FAST_SUM_OP(TYPENAME, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
+template <typename T>
+__device__ void fast_max(
+    const size_t src_numel,
+    const size_t el_to_sum_per_block,
+    const size_t num_dims,
+    const size_t *info,
+    const T *src,
+    T *dst
+) {
+  const size_t *dims = info;
+  const size_t *strides = info + num_dims;
+
+  __shared__ T shr[BLOCK_SIZE];
+  size_t tid = threadIdx.x;
+  size_t dst_id = blockIdx.x;
+
+  shr[tid] = -INFINITY;
+  // Elements reduced in this block range from dst_id * el_to_sum_per_block
+  // to (dst_id + 1) * el_to_sum_per_block.
+  size_t start_idx = dst_id * el_to_sum_per_block;
+  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+  size_t idx = start_idx + tid;
+
+  while (idx < stop_idx) {
+    // TODO: Fast version for the contiguous case.
+    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+    shr[tid] = maxg(shr[tid], src[strided_i]);
+    idx += blockDim.x;
+  }
+
+  // Parallel reduction, see the slides:
+  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
+  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    __syncthreads();
+    if (tid < s) shr[tid] = maxg(shr[tid], shr[tid + s]);
+  }
+
+  if (tid == 0) dst[dst_id] = shr[0];
+}
+
+template <typename T>
+__device__ void fast_min(
+    const size_t src_numel,
+    const size_t el_to_sum_per_block,
+    const size_t num_dims,
+    const size_t *info,
+    const T *src,
+    T *dst
+) {
+  const size_t *dims = info;
+  const size_t *strides = info + num_dims;
+
+  __shared__ T shr[BLOCK_SIZE];
+  size_t tid = threadIdx.x;
+  size_t dst_id = blockIdx.x;
+
+  shr[tid] = INFINITY;
+  // Elements reduced in this block range from dst_id * el_to_sum_per_block
+  // to (dst_id + 1) * el_to_sum_per_block.
+  size_t start_idx = dst_id * el_to_sum_per_block;
+  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+  size_t idx = start_idx + tid;
+
+  while (idx < stop_idx) {
+    // TODO: Fast version for the contiguous case.
+    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+    shr[tid] = ming(shr[tid], src[strided_i]);
+    idx += blockDim.x;
+  }
+
+  // Parallel reduction, see the slides:
+  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
+  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    __syncthreads();
+    if (tid < s) shr[tid] = ming(shr[tid], shr[tid + s]);
+  }
+
+  if (tid == 0) dst[dst_id] = shr[0];
+}
+
+#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
+extern "C" __global__ void MIN_NAME( \
+    const size_t src_numel, \
+    const size_t el_to_sum_per_block, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+  fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+} \
+extern "C" __global__ void MAX_NAME( \
+    const size_t src_numel, \
+    const size_t el_to_sum_per_block, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+  fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+} \
+extern "C" __global__ void SUM_NAME( \
     const size_t src_numel, \
     const size_t el_to_sum_per_block, \
     const size_t num_dims, \
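
Each FAST_OP instantiation stamps out three extern "C" entry points that forward to the templated device functions above. For example, the MIN_NAME kernel produced by FAST_OP(float, fast_min_f32, fast_max_f32, fast_sum_f32) expands to:

extern "C" __global__ void fast_min_f32(
    const size_t src_numel,
    const size_t el_to_sum_per_block,
    const size_t num_dims,
    const size_t *info,
    const float *src,
    float *dst
) {
  // Forwards to the templated reduction with T deduced as float.
  fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst);
}

The kernels address the source through get_strided_index, which lives in cuda_utils.cuh and is not part of this diff. A sketch of the usual row-major stride decomposition such a helper performs (illustrative only, not the committed implementation):

__device__ size_t get_strided_index(
    size_t idx, size_t num_dims, const size_t *dims, const size_t *strides
) {
  size_t strided_i = 0;
  for (int d = (int)num_dims - 1; d >= 0; d--) {
    // Peel the coordinate along dimension d off the linear index,
    // then weight it by that dimension's stride.
    strided_i += (idx % dims[d]) * strides[d];
    idx /= dims[d];
  }
  return strided_i;
}
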
@@ -106,18 +209,18 @@ extern "C" __global__ void FN_NAME( \
 #if __CUDA_ARCH__ >= 800
 SUM_OP(__nv_bfloat16, sum_bf16)
-FAST_SUM_OP(__nv_bfloat16, fast_sum_bf16)
+FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_sum_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 SUM_OP(__half, sum_f16)
-FAST_SUM_OP(__half, fast_sum_f16)
+FAST_OP(__half, fast_min_f16, fast_max_f16, fast_sum_f16)
 #endif
 
 SUM_OP(float, sum_f32)
 SUM_OP(double, sum_f64)
 SUM_OP(uint32_t, sum_u32)
-FAST_SUM_OP(float, fast_sum_f32)
-FAST_SUM_OP(double, fast_sum_f64)
-FAST_SUM_OP(uint32_t, fast_sum_u32)
+FAST_OP(float, fast_min_f32, fast_max_f32, fast_sum_f32)
+FAST_OP(double, fast_min_f64, fast_max_f64, fast_sum_f64)
+FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_sum_u32)
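
A usage note on the launch contract these kernels assume: one block per output element (blockIdx.x is the destination index), BLOCK_SIZE threads per block, and info packing the dims followed by the strides. Below is a minimal host-side sketch for reducing each row of a contiguous [n_rows, n_cols] f32 tensor to its maximum; the launch_row_max wrapper and its names are illustrative and not part of this commit:

#include <cstddef>
#include <cuda_runtime.h>

// Kernel generated by FAST_OP(float, fast_min_f32, fast_max_f32, fast_sum_f32).
extern "C" __global__ void fast_max_f32(
    const size_t src_numel, const size_t el_to_sum_per_block,
    const size_t num_dims, const size_t *info,
    const float *src, float *dst);

void launch_row_max(const float *d_src, float *d_dst,
                    size_t n_rows, size_t n_cols) {
  // info = [dims..., strides...] for a row-major contiguous [n_rows, n_cols].
  const size_t h_info[4] = {n_rows, n_cols, n_cols, 1};
  size_t *d_info = nullptr;
  cudaMalloc(&d_info, sizeof(h_info));
  cudaMemcpy(d_info, h_info, sizeof(h_info), cudaMemcpyHostToDevice);

  // One block per row; 1024 threads to match BLOCK_SIZE in reduce.cu.
  // Each block reduces el_to_sum_per_block = n_cols consecutive elements.
  fast_max_f32<<<n_rows, 1024>>>(n_rows * n_cols, n_cols, 2, d_info, d_src, d_dst);
  cudaFree(d_info);
}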