Add a couple cuda kernels from dfdx.
kernels/README.md (new file)
@@ -0,0 +1,4 @@
# candle-kernels

This crate contains CUDA kernels used from candle. Some of these implementations
come from the [dfdx crate](https://github.com/coreylowman/dfdx).
kernels/src/binary_add.cu (new file)
@@ -0,0 +1,20 @@
#include "binary_op_macros.cuh"

struct BinaryAddOp {};

#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, badd_fwd_f16, badd_bwd_lhs_f16, badd_bwd_rhs_f16, BinaryAddOp,
    x + y,
    1.0,
    1.0)
#endif

BINARY_OP(float, badd_fwd_f32, badd_bwd_lhs_f32, badd_bwd_rhs_f32, BinaryAddOp,
    x + y,
    1.0,
    1.0)

BINARY_OP(double, badd_fwd_f64, badd_bwd_lhs_f64, badd_bwd_rhs_f64, BinaryAddOp,
    x + y,
    1.0,
    1.0)
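For readers skimming the diff, here is a simplified sketch of what BINARY_OP(float, badd_fwd_f32, ...) produces via LONG_BINARY_OP from binary_op_macros.cuh below. It is not the literal preprocessor output, and the `_sketch` suffix is hypothetical; the body mirrors the macro: a grid-stride loop that decodes the flat output index into strided lhs/rhs offsets and writes x + y.

// Simplified sketch (not the literal macro expansion) of the generated forward kernel.
#include <cstddef>

struct BinaryAddOp {};

extern "C" __global__ void badd_fwd_f32_sketch(
    const BinaryAddOp op,
    const size_t numel,
    const size_t num_dims,
    const size_t *info,      // laid out as [dims, lhs_strides, rhs_strides]
    const float *lhs,
    const float *rhs,
    float *out
) {
    const size_t *dims = info;
    const size_t *lhs_strides = info + num_dims;
    const size_t *rhs_strides = info + 2 * num_dims;
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += blockDim.x * gridDim.x) {
        unsigned int tmp_i = i, lhs_i = 0, rhs_i = 0;
        // Decode the flat index into per-tensor strided offsets, innermost dim first.
        for (int d = num_dims - 1; d >= 0; d--) {
            unsigned int i_dim = tmp_i % dims[d];
            lhs_i += i_dim * lhs_strides[d];
            rhs_i += i_dim * rhs_strides[d];
            tmp_i /= dims[d];
        }
        float x = lhs ? lhs[lhs_i] : out[i];
        float y = rhs ? rhs[rhs_i] : out[i];
        float fx;
        fx = (x + y);        // this is the `fx = (FUNC);` that BINARY_OP splices in
        out[i] = fx;
    }
}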
kernels/src/binary_div.cu (new file)
@@ -0,0 +1,21 @@
#include "binary_op_macros.cuh"

struct BinaryDivOp {};

#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, bdiv_fwd_f16, bdiv_bwd_lhs_f16, bdiv_bwd_rhs_f16, BinaryDivOp,
    x / y,
    recipg(y),
    -x / (y * y))
#endif

BINARY_OP(float, bdiv_fwd_f32, bdiv_bwd_lhs_f32, bdiv_bwd_rhs_f32, BinaryDivOp,
    x / y,
    recipg(y),
    -x / (y * y))

BINARY_OP(double, bdiv_fwd_f64, bdiv_bwd_lhs_f64, bdiv_bwd_rhs_f64, BinaryDivOp,
    x / y,
    recipg(y),
    -x / (y * y))
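The last two macro arguments (DFDX, DFDY) are the partial derivatives of the forward expression with respect to lhs and rhs; the backward kernels multiply them by the incoming gradient. For division they are the usual quotient-rule terms, matching the expressions above:

\[
f(x, y) = \frac{x}{y}, \qquad
\frac{\partial f}{\partial x} = \frac{1}{y} \;\; (\texttt{recipg(y)}), \qquad
\frac{\partial f}{\partial y} = -\frac{x}{y^{2}} \;\; (\texttt{-x / (y * y)}).
\]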
kernels/src/binary_mul.cu (new file)
@@ -0,0 +1,21 @@
#include "binary_op_macros.cuh"

struct BinaryMulKernalOp {};

#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, bmul_fwd_f16, bmul_bwd_lhs_f16, bmul_bwd_rhs_f16, BinaryMulKernalOp,
    x * y,
    y,
    x)
#endif

BINARY_OP(float, bmul_fwd_f32, bmul_bwd_lhs_f32, bmul_bwd_rhs_f32, BinaryMulKernalOp,
    x * y,
    y,
    x)

BINARY_OP(double, bmul_fwd_f64, bmul_bwd_lhs_f64, bmul_bwd_rhs_f64, BinaryMulKernalOp,
    x * y,
    y,
    x)
kernels/src/binary_op_macros.cuh (new file)
@@ -0,0 +1,101 @@
#include "cuda_utils.cuh"

#define LONG_BINARY_OP(TYPENAME, FORWARD, BACKWARD_LHS, BACKWARD_RHS, OP_STRUCT, FUNC, DFDX, DFDY) \
extern "C" __global__ void FORWARD( \
    const OP_STRUCT op, \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const TYPENAME *lhs, \
    const TYPENAME *rhs, \
    TYPENAME *out \
) { \
    const size_t *dims = info; \
    const size_t *lhs_strides = info + num_dims; \
    const size_t *rhs_strides = info + 2 * num_dims; \
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
        unsigned int tmp_i = i; \
        unsigned int lhs_i = 0; \
        unsigned int rhs_i = 0; \
        for (int d = num_dims - 1; d >= 0; d--) { \
            unsigned int i_dim = tmp_i % dims[d]; \
            lhs_i += i_dim * lhs_strides[d]; \
            rhs_i += i_dim * rhs_strides[d]; \
            tmp_i /= dims[d]; \
        } \
        TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
        TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
        TYPENAME fx; \
        FUNC \
        out[i] = fx; \
    } \
} \
\
extern "C" __global__ void BACKWARD_LHS( \
    const OP_STRUCT op, \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const TYPENAME *lhs, \
    TYPENAME *grad_lhs, \
    const size_t chunk_len, \
    const TYPENAME *rhs, \
    const TYPENAME *grad_out \
) { \
    const size_t *dims = info + 0 * num_dims; \
    const size_t *out_strides = info + 1 * num_dims; \
    const size_t *rhs_strides = info + 2 * num_dims; \
    TYPENAME zero = 0.0; \
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
        unsigned int tmp_i = i; \
        unsigned int out_i = 0; \
        unsigned int rhs_i = 0; \
        for (int d = num_dims - 1; d >= 0; d--) { \
            unsigned int i_dim = tmp_i % dims[d]; \
            out_i += i_dim * out_strides[d]; \
            rhs_i += i_dim * rhs_strides[d]; \
            tmp_i /= dims[d]; \
        } \
        TYPENAME x = lhs ? lhs[i / chunk_len] : zero; \
        TYPENAME y = rhs ? rhs[rhs_i] : zero; \
        TYPENAME go = grad_out[out_i]; \
        TYPENAME dfdx = (DFDX); \
        chunk_sum(chunk_len, dfdx * go, grad_lhs); \
    } \
} \
\
extern "C" __global__ void BACKWARD_RHS( \
    const OP_STRUCT op, \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const TYPENAME *lhs, \
    const TYPENAME *rhs, \
    TYPENAME *grad_rhs, \
    const size_t chunk_len, \
    const TYPENAME *grad_out \
) { \
    const size_t *dims = info + 3 * num_dims; \
    const size_t *out_strides = info + 4 * num_dims; \
    const size_t *lhs_strides = info + 5 * num_dims; \
    TYPENAME zero = 0.0; \
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
        unsigned int tmp_i = i; \
        unsigned int lhs_i = 0; \
        unsigned int out_i = 0; \
        for (int d = num_dims - 1; d >= 0; d--) { \
            unsigned int i_dim = tmp_i % dims[d]; \
            lhs_i += i_dim * lhs_strides[d]; \
            out_i += i_dim * out_strides[d]; \
            tmp_i /= dims[d]; \
        } \
        TYPENAME x = lhs ? lhs[lhs_i] : zero; \
        TYPENAME y = rhs ? rhs[i / chunk_len] : zero; \
        TYPENAME go = grad_out[out_i]; \
        TYPENAME dfdy = (DFDY); \
        chunk_sum(chunk_len, dfdy * go, grad_rhs); \
    } \
}

#define BINARY_OP(TYPENAME, FORWARD, BACKWARD_LHS, BACKWARD_RHS, OP_STRUCT, FUNC, DFDX, DFDY) \
    LONG_BINARY_OP(TYPENAME, FORWARD, BACKWARD_LHS, BACKWARD_RHS, OP_STRUCT, fx = (FUNC);, DFDX, DFDY)
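The forward kernel reads its shape metadata from a single info buffer laid out as [dims, lhs_strides, rhs_strides]; the backward kernels index further into the same buffer. As a standalone illustration of that index arithmetic, the small program below uses a hypothetical helper (strided_offset, not part of the crate) to decode flat indices for a broadcasted (2, 3) case; it should compile with nvcc as-is.

// Illustration only: decode a flat index into a strided offset, as LONG_BINARY_OP does.
#include <cstdio>
#include <cstddef>

__host__ __device__ unsigned int strided_offset(
    unsigned int i, size_t num_dims, const size_t *dims, const size_t *strides
) {
    unsigned int offset = 0;
    // Peel off one dimension at a time, starting from the innermost.
    for (int d = (int)num_dims - 1; d >= 0; d--) {
        offset += (i % dims[d]) * strides[d];
        i /= dims[d];
    }
    return offset;
}

int main() {
    // A (2, 3) output where rhs is a broadcast row of shape (1, 3): its stride
    // along the broadcast dimension is 0, so both rows read the same 3 elements.
    const size_t dims[2] = {2, 3};
    const size_t lhs_strides[2] = {3, 1}; // contiguous (2, 3)
    const size_t rhs_strides[2] = {0, 1}; // broadcast along dim 0
    for (unsigned int i = 0; i < 6; i++) {
        printf("i=%u lhs_i=%u rhs_i=%u\n", i,
               strided_offset(i, 2, dims, lhs_strides),
               strided_offset(i, 2, dims, rhs_strides));
    }
    return 0;
}

For i = 3, 4, 5 this prints rhs_i = 0, 1, 2 again, which is exactly how the zero stride implements broadcasting without copying data.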
kernels/src/binary_sub.cu (new file)
@@ -0,0 +1,21 @@
#include "binary_op_macros.cuh"

struct BinarySubKernelOp {};

#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, bsub_fwd_f16, bsub_bwd_lhs_f16, bsub_bwd_rhs_f16, BinarySubKernelOp,
    x - y,
    1.0,
    -1.0)
#endif

BINARY_OP(float, bsub_fwd_f32, bsub_bwd_lhs_f32, bsub_bwd_rhs_f32, BinarySubKernelOp,
    x - y,
    1.0,
    -1.0)

BINARY_OP(double, bsub_fwd_f64, bsub_bwd_lhs_f64, bsub_bwd_rhs_f64, BinarySubKernelOp,
    x - y,
    1.0,
    -1.0)
kernels/src/compatibility.cuh (new file)
@@ -0,0 +1,171 @@
#include "cuda_fp16.h"

// Table showing which features are supported on which compute capability
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications

// FIXME: the minimum compute capabilities are just guesses since the table is not specific enough

// #if __CUDA_ARCH__ < 600
// __device__ __forceinline__ __half __hmax(__half a, __half b) {
//     return __float2half(fmaxf(__half2float(a), __half2float(b)));
// }
// __device__ __forceinline__ __half __hmin(__half a, __half b) {
//     return __float2half(fminf(__half2float(a), __half2float(b)));
// }
// #endif

#if __CUDA_ARCH__ < 800
__device__ __forceinline__ __half __hmax_nan(__half a, __half b) {
    // return __hisnan(a) ? a : (__hisnan(b) ? b : __hmax(a, b));
}
__device__ __forceinline__ __half __hmin_nan(__half a, __half b) {
    // return __hisnan(a) ? a : (__hisnan(b) ? b : __hmin(a, b));
}
#endif

#if __CUDA_ARCH__ < 600
// Copied from https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
__device__ double atomicAdd(double* address, double val) {
    unsigned long long int* address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull, assumed;

    do {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed,
                        __double_as_longlong(val +
                               __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
    } while (assumed != old);

    return __longlong_as_double(old);
}
#endif


#if __CUDA_ARCH__ < 700
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomicadd
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
__device__ __half atomicAdd(__half *address, __half val) {
    // unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    // unsigned int old = *address_as_ui;
    // unsigned int assumed;
    // bool unaligned = (size_t) address & 2;
    // do {
    //     assumed = old;
    //     unsigned int hsum;
    //     hsum = unaligned ? (old >> 16) : (old & 0xffff);
    //     hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
    //     old = atomicCAS(address_as_ui, assumed,
    //         unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
    //     );

    // } while (assumed != old);
    // return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
}
#endif


__device__ __forceinline__ __half atomicMaxf(__half* address, __half val) {
#if __CUDA_ARCH__ < 700
    // On older GPUs we do not have access to atomicCAS for shorts, so we have to do some trickery.
    // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
    unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    bool unaligned = (size_t) address & 2;
    do {
        assumed = old;
        unsigned int hmax;
        hmax = unaligned ? (old >> 16) : (old & 0xffff);
        hmax = __half_as_ushort(__hmax_nan(val, __ushort_as_half(hmax)));
        old = atomicCAS(address_as_ui, assumed,
            unaligned ? (old & 0xffff) | (hmax << 16) : (old & 0xffff0000) | hmax
        );

    } while (assumed != old);
    return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
#else
    // Based on https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
    unsigned short int* casted_address = (unsigned short int*)address;
    unsigned short int old = *casted_address;
    unsigned short int assumed;
    do {
        assumed = old;
        old = atomicCAS(casted_address, assumed, __half_as_ushort(__hmax_nan(val, __ushort_as_half(assumed))));
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
    } while (assumed != old);
    return __ushort_as_half(old);
#endif
}

// atomicMax is not implemented for floats,
// solution copied https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
__device__ __forceinline__ float atomicMaxf(float * addr, float value) {
    if (signbit(value)) {
        return __uint_as_float(atomicMin((unsigned int *)addr, __float_as_uint(value)));
    } else {
        return __int_as_float(atomicMax((int *)addr, __float_as_int(value)));
    }
}

__device__ __forceinline__ double atomicMaxf(double * addr, double value) {
    if (signbit(value)) {
        return __longlong_as_double(atomicMin((unsigned long long int *)addr, __double_as_longlong(value)));
    } else {
        return __longlong_as_double(atomicMax((long long int *)addr, __double_as_longlong(value)));
    }
}


__device__ __forceinline__ __half atomicMinf(__half* address, __half val) {
#if __CUDA_ARCH__ < 700
    // On older GPUs we do not have access to atomicCAS for shorts, so we have to do some trickery.
    // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
    unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    bool unaligned = (size_t) address & 2;
    do {
        assumed = old;
        unsigned int hmin;
        hmin = unaligned ? (old >> 16) : (old & 0xffff);
        hmin = __half_as_ushort(__hmin_nan(val, __ushort_as_half(hmin)));
        old = atomicCAS(address_as_ui, assumed,
            unaligned ? (old & 0xffff) | (hmin << 16) : (old & 0xffff0000) | hmin
        );

    } while (assumed != old);
    return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
#else
    // Based on https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
    unsigned short int* casted_address = (unsigned short int*)address;
    unsigned short int old = *casted_address;
    unsigned short int assumed;
    do {
        assumed = old;
        old = atomicCAS(casted_address, assumed, __half_as_ushort(__hmin_nan(val, __ushort_as_half(assumed))));
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
    } while (assumed != old);
    return __ushort_as_half(old);
#endif
}

// atomicMin is not implemented for floats,
// solution copied https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
__device__ __forceinline__ float atomicMinf(float * addr, float value) {
    if (signbit(value)) {
        return __uint_as_float(atomicMax((unsigned int *)addr, __float_as_uint(value)));
    } else {
        return __int_as_float(atomicMin((int *)addr, __float_as_int(value)));
    }
}

__device__ __forceinline__ double atomicMinf(double * addr, double value) {
    if (signbit(value)) {
        return __longlong_as_double(atomicMax((unsigned long long int *)addr, __double_as_longlong(value)));
    } else {
        return __longlong_as_double(atomicMin((long long int *)addr, __double_as_longlong(value)));
    }
}
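All of these fallbacks follow the same compare-and-swap recipe: reinterpret the value's bits as an integer type that atomicCAS supports, recompute the result from the last observed value, and retry until the CAS succeeds. As an illustration of that pattern only, the hypothetical atomicMulf below (not part of this header) applies it to a float product.

// Hypothetical example of the CAS retry loop used throughout compatibility.cuh.
__device__ __forceinline__ float atomicMulf(float *addr, float value) {
    unsigned int *addr_as_ui = (unsigned int *)addr;
    unsigned int old = *addr_as_ui;
    unsigned int assumed;
    do {
        assumed = old;
        // Recompute the product from the last observed bits and try to publish it.
        old = atomicCAS(addr_as_ui, assumed,
                        __float_as_uint(value * __uint_as_float(assumed)));
        // Integer comparison avoids hanging on NaN (since NaN != NaN).
    } while (assumed != old);
    return __uint_as_float(old);
}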
kernels/src/cuda_utils.cuh (new file)
@@ -0,0 +1,138 @@
#include "cuda_fp16.h"
#include "compatibility.cuh"

__device__ unsigned int get_strided_index(
    unsigned int idx,
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    unsigned int strided_i = 0;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

__device__ unsigned int restrided(
    const unsigned int strided_i,
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides,
    const size_t *new_strides
) {
    unsigned int idx = 0;
    for (int d = 0; d < num_dims; d++) {
        idx += (strides[d] == 0 ? 0 : (strided_i / strides[d]) % dims[d]) * new_strides[d];
    }
    return idx;
}

// Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
// Input must be less than or equal to 2 ^ 16
// used in reductions
__device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) {
    v--;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v++;
    return v;
}

// Efficiently computes the sum of each chunk in "data" of size chunk_len, and
// stores the sums in out[i / chunk_len]
template<typename T>
__device__ void chunk_sum(
    const size_t chunk_len,
    const T data,
    T* out
) {
    __shared__ T buf[1024];

    // assumes that threads where i >= numel have already exited
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int block_i = threadIdx.x;

    // Fall back to atomicAdd if chunk_len is small to reduce overhead
    if (chunk_len <= 2) {
        atomicAdd(out + i / chunk_len, data);
        return;
    }
    buf[block_i] = data;

    unsigned int chunk_i = i % chunk_len;
    unsigned int chunk_start = max((int)(block_i - chunk_i), 0);
    unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x);

    chunk_i = block_i - chunk_start;

    size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x);
    size_t incr = next_power_of_two(max_chunk_len) >> 1;

    __syncthreads();

    // Uses sequential addressing as discussed in
    // https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
    for (; incr > 0; incr >>= 1) {
        unsigned int block_i_2 = block_i + incr;

        if (block_i_2 < chunk_end && chunk_i < incr) {
            // This is sound because __syncthreads and the conditions above
            // ensure that no data races occur
            buf[block_i] += buf[block_i_2];
        }

        __syncthreads();
    }

    if (block_i == chunk_start) {
        atomicAdd(out + i / chunk_len, buf[block_i]);
    }
}

__device__ __forceinline__ bool isnang(float a) { return isnan(a); }
__device__ __forceinline__ bool isnang(double a) { return isnan(a); }
__device__ __forceinline__ float recipg(float a) { return 1.0 / a; }
__device__ __forceinline__ double recipg(double a) { return 1.0 / a; }
__device__ __forceinline__ float cosg(float a) { return cosf(a); }
__device__ __forceinline__ double cosg(double a) { return cos(a); }
__device__ __forceinline__ float sing(float a) { return sinf(a); }
__device__ __forceinline__ double sing(double a) { return sin(a); }
__device__ __forceinline__ float sqrtg(float a) { return sqrtf(a); }
__device__ __forceinline__ double sqrtg(double a) { return sqrt(a); }
__device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
__device__ __forceinline__ double powg(double a, double b) { return pow(a, b); }
__device__ __forceinline__ float tanhg(float a) { return tanhf(a); }
__device__ __forceinline__ double tanhg(double a) { return tanh(a); }
__device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); }
__device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); }
__device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); }
__device__ __forceinline__ double ming(double a, double b) { return fmin(a, b); }
__device__ __forceinline__ float logg(float a) { return logf(a); }
__device__ __forceinline__ double logg(double a) { return log(a); }
__device__ __forceinline__ float expg(float a) { return expf(a); }
__device__ __forceinline__ double expg(double a) { return exp(a); }
__device__ __forceinline__ float absg(float a) { return fabsf(a); }
__device__ __forceinline__ double absg(double a) { return fabs(a); }
__device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); }
__device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); }

#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
__device__ __forceinline__ __half cosg(__half a) { return hcos(a); }
__device__ __forceinline__ __half sing(__half a) { return hsin(a); }
__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __half logg(__half a) { return hlog(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
__device__ __forceinline__ __half absg(__half a) { return __habs(a); }
__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
#endif
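A rough usage sketch for chunk_sum, under two assumptions stated by its own comments: the output buffer is zero-initialized (results are accumulated with atomicAdd), and threads whose flat index is past numel exit before calling it. The kernel and host driver names below are hypothetical; the expected sums for the sample data are noted in a comment.

// Sketch: sum every chunk_len consecutive elements of `data` into out[i / chunk_len].
#include <cstdio>
#include "cuda_utils.cuh"

extern "C" __global__ void chunked_sum_f32(
    const size_t numel, const size_t chunk_len, const float *data, float *out
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= numel) {
        return; // chunk_sum assumes out-of-range threads have already exited
    }
    chunk_sum(chunk_len, data[i], out);
}

int main() {
    const size_t numel = 8, chunk_len = 4;
    float h_data[numel] = {1, 2, 3, 4, 5, 6, 7, 8};
    float h_out[2] = {0, 0};
    float *d_data, *d_out;
    cudaMalloc(&d_data, numel * sizeof(float));
    cudaMalloc(&d_out, 2 * sizeof(float));
    cudaMemcpy(d_data, h_data, numel * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_out, 0, 2 * sizeof(float)); // chunk_sum accumulates, so start from zero
    chunked_sum_f32<<<1, 256>>>(numel, chunk_len, d_data, d_out);
    cudaMemcpy(h_out, d_out, 2 * sizeof(float), cudaMemcpyDeviceToHost);
    printf("%f %f\n", h_out[0], h_out[1]); // expected: 10 and 26
    cudaFree(d_data);
    cudaFree(d_out);
    return 0;
}

This mirrors how the backward kernels above use chunk_sum: each thread contributes one dfdx * grad_out term, and the per-chunk partials land in the (smaller, possibly broadcast) gradient buffer.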
@@ -1,2 +1,6 @@
 pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
+pub const BINARY_ADD: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_add.ptx"));
+pub const BINARY_DIV: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_div.ptx"));
+pub const BINARY_MUL: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_mul.ptx"));
+pub const BINARY_SUB: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_sub.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));