Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
@@ -144,7 +144,8 @@ __device__ __forceinline__ double copysigng(double a, double b) { return copysig
__device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); }
__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }

__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
@@ -125,7 +125,116 @@ fast_min(const size_t src_numel, const size_t el_to_sum_per_block,
    dst[dst_id] = shr[0];
}

-#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
template <typename T>
__device__ void
fast_argmin(const size_t src_numel, const size_t el_to_sum_per_block,
            const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) {
    const size_t *dims = info;
    const size_t *strides = info + num_dims;

    __shared__ T shr[BLOCK_SIZE];
    __shared__ uint32_t shr_index[BLOCK_SIZE];
    size_t tid = threadIdx.x;
    size_t dst_id = blockIdx.x;

    // Not sure how that works on uint32_t and uint8_t but it seems to do ok.
    shr[tid] = INFINITY;
    shr_index[tid] = 0xFFFFFFFF;
    bool not_set = true;
    // Elements summed in this block range from dst_id * el_to_sum_per_block
    // to (dst_id + 1) * el_to_sum_per_block.
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
        if (not_set || src[strided_i] < shr[tid]) {
            shr[tid] = src[strided_i];
            // Assume that the reduction takes place over the last dimension which is contiguous.
            shr_index[tid] = idx % dims[num_dims - 1];
            not_set = false;
        }
        idx += blockDim.x;
    }

    // Parallel reduction, see the slides:
    // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
    // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        __syncthreads();
        if (tid < s && shr[tid + s] < shr[tid]) {
            shr[tid] = shr[tid + s];
            shr_index[tid] = shr_index[tid + s];
        }
    }

    if (tid == 0)
        dst[dst_id] = shr_index[0];
}
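(The get_strided_index helper called above is not part of this diff; it lives in the existing kernel utilities. As a minimal sketch, assuming the info layout used above, dims first, then strides, and row-major logical indexing, it could look like this:)

__device__ size_t get_strided_index(size_t idx, const size_t num_dims,
                                    const size_t *dims, const size_t *strides) {
    // Walk the dimensions from innermost to outermost, mapping a linear index
    // over the logical shape to an offset into possibly non-contiguous storage.
    size_t strided_i = 0;
    for (size_t d = 0; d < num_dims; d++) {
        size_t dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}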

template <typename T>
__device__ void
fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
            const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) {
    const size_t *dims = info;
    const size_t *strides = info + num_dims;

    __shared__ T shr[BLOCK_SIZE];
    __shared__ uint32_t shr_index[BLOCK_SIZE];
    size_t tid = threadIdx.x;
    size_t dst_id = blockIdx.x;

    shr[tid] = -INFINITY;
    shr_index[tid] = 0xFFFFFFFF;
    bool not_set = true;
    // Elements summed in this block range from dst_id * el_to_sum_per_block
    // to (dst_id + 1) * el_to_sum_per_block.
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
        if (not_set || src[strided_i] > shr[tid]) {
            shr[tid] = src[strided_i];
            // Assume that the reduction takes place over the last dimension which is contiguous.
            shr_index[tid] = idx % dims[num_dims - 1];
            not_set = false;
        }
        idx += blockDim.x;
    }

    // Parallel reduction, see the slides:
    // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
    // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        __syncthreads();
        if (tid < s && shr[tid + s] > shr[tid]) {
            shr[tid] = shr[tid + s];
            shr_index[tid] = shr_index[tid + s];
        }
    }

    if (tid == 0)
        dst[dst_id] = shr_index[0];
}
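(Note that fast_argmax is the exact mirror of fast_argmin: shared memory is seeded with -INFINITY instead of INFINITY, and both the strided scan and the tree reduction flip the comparison from < to >.)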

+#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, ARGMIN_NAME, ARGMAX_NAME, SUM_NAME) \
  extern "C" __global__ void ARGMIN_NAME( \
      const size_t src_numel, const size_t el_to_sum_per_block, \
      const size_t num_dims, const size_t *info, const TYPENAME *src, \
      uint32_t *dst) { \
    fast_argmin(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
  } \
  extern "C" __global__ void ARGMAX_NAME( \
      const size_t src_numel, const size_t el_to_sum_per_block, \
      const size_t num_dims, const size_t *info, const TYPENAME *src, \
      uint32_t *dst) { \
    fast_argmax(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
  } \
  extern "C" __global__ void MIN_NAME( \
      const size_t src_numel, const size_t el_to_sum_per_block, \
      const size_t num_dims, const size_t *info, const TYPENAME *src, \
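(Each FAST_OP instantiation in the next hunk now stamps out five extern "C" kernels per dtype. For example, FAST_OP(float, ..., fast_argmin_f32, fast_argmax_f32, ...) expands the argmin wrapper to roughly:)

extern "C" __global__ void fast_argmin_f32(
    const size_t src_numel, const size_t el_to_sum_per_block,
    const size_t num_dims, const size_t *info, const float *src,
    uint32_t *dst) {
    fast_argmin(src_numel, el_to_sum_per_block, num_dims, info, src, dst);
}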
@@ -183,18 +292,19 @@ fast_min(const size_t src_numel, const size_t el_to_sum_per_block,

#if __CUDA_ARCH__ >= 800
SUM_OP(__nv_bfloat16, sum_bf16)
-FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_sum_bf16)
+FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif

#if __CUDA_ARCH__ >= 530
SUM_OP(__half, sum_f16)
-FAST_OP(__half, fast_min_f16, fast_max_f16, fast_sum_f16)
+FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif

SUM_OP(float, sum_f32)
SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)

-FAST_OP(float, fast_min_f32, fast_max_f32, fast_sum_f32)
-FAST_OP(double, fast_min_f64, fast_max_f64, fast_sum_f64)
-FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_sum_u32)
+FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
+FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
+FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32)
+FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8)
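(For illustration, a minimal host-side harness for the new argmax kernel might look like the sketch below. It is not part of the PR and carries assumptions: the file is compiled together with reduce.cu, BLOCK_SIZE is at least 256, and the input is a contiguous [rows, cols] f32 tensor reduced over its last dimension, with the info buffer packing dims followed by strides as the kernels above expect.)

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Declaration matching the extern "C" kernel generated above by
// FAST_OP(float, ..., fast_argmax_f32, ...).
extern "C" __global__ void fast_argmax_f32(
    const size_t src_numel, const size_t el_to_sum_per_block,
    const size_t num_dims, const size_t *info, const float *src,
    uint32_t *dst);

int main() {
    const size_t rows = 4, cols = 256, numel = rows * cols;
    // info = dims followed by strides, row-major contiguous layout.
    const size_t h_info[4] = {rows, cols, cols, 1};
    static float h_src[numel];
    for (size_t i = 0; i < numel; ++i) h_src[i] = (float)(i % cols);

    size_t *d_info;
    float *d_src;
    uint32_t *d_dst;
    cudaMalloc(&d_info, sizeof(h_info));
    cudaMalloc(&d_src, sizeof(h_src));
    cudaMalloc(&d_dst, rows * sizeof(uint32_t));
    cudaMemcpy(d_info, h_info, sizeof(h_info), cudaMemcpyHostToDevice);
    cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);

    // One block per output element (dst_id = blockIdx.x); each block scans
    // el_to_sum_per_block = cols inputs. The block size must be a power of
    // two and no larger than BLOCK_SIZE for the tree reduction to be valid.
    fast_argmax_f32<<<rows, 256>>>(numel, cols, 2, d_info, d_src, d_dst);

    uint32_t h_dst[rows];
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);
    for (size_t r = 0; r < rows; ++r)
        printf("row %zu: argmax = %u\n", r, h_dst[r]); // values 0..255 per row, so expect 255
    return 0;
}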