Add the elu cuda kernel. (#114)

Author: Laurent Mazare (committed via GitHub)
Date: 2023-07-10 07:57:01 +01:00
Parent: 270997a055
Commit: bc3be6f9b0

@@ -36,12 +36,46 @@ __device__ __forceinline__ T gelu_fwd(T x) {
  return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
}
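
// ELU: identity for x > 0, alpha * (exp(x) - 1) otherwise.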
template<typename T>
__device__ __forceinline__ T elu_fwd(T x, T alpha) {
  if (x > static_cast<T>(0)) {
    return x;
  }
  return alpha * (expg(x) - static_cast<T>(1));
}
template<typename T>
__device__ __forceinline__ T relu_fwd(T x) {
  T zero = 0.;
  return maxg(x, zero);
}
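
// Variant of UNARY_OP whose expression also receives a scalar parameter (here, elu's alpha).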
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const TYPENAME param, \
    const TYPENAME *inp, \
    TYPENAME *out \
) { \
    const size_t *dims = info; \
    const size_t *strides = info + num_dims; \
    if (is_contiguous(num_dims, dims, strides)) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            TYPENAME x = inp ? inp[i] : out[i]; \
            out[i] = FUNC; \
        } \
    } \
    else { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
            TYPENAME x = inp ? inp[strided_i] : out[i]; \
            out[i] = FUNC; \
        } \
    } \
} \
#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
@@ -55,6 +89,7 @@ UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
#endif

#if __CUDA_ARCH__ >= 530
@@ -69,6 +104,7 @@ UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x))
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
#endif

UNARY_OP(uint8_t, ucopy_u8, x)
@@ -95,3 +131,5 @@ UNARY_OP(float, ugelu_f32, gelu_fwd(x))
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x))
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
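
For context, a minimal host-side sketch of how one of the generated kernels (uelu_f32) could be launched directly on a contiguous 1-D buffer. This is not part of the commit: the buffer names, launch configuration, and the [dims..., strides...] layout of the info array are illustrative assumptions, and the declaration below assumes this file is compiled together with (or device-linked against) the kernel source; candle's Rust wrappers normally handle all of this.

#include <cstddef>
#include <cuda_runtime.h>

// Signature produced by UNARY_OP1(float, uelu_f32, elu_fwd(x, param)).
extern "C" __global__ void uelu_f32(const size_t numel, const size_t num_dims,
                                    const size_t *info, const float param,
                                    const float *inp, float *out);

int main() {
  const size_t numel = 1024;
  // For a contiguous 1-D tensor: info = [dims..., strides...] = [numel, 1].
  const size_t h_info[2] = {numel, 1};

  float *d_inp = nullptr, *d_out = nullptr;
  size_t *d_info = nullptr;
  cudaMalloc(&d_inp, numel * sizeof(float));
  cudaMalloc(&d_out, numel * sizeof(float));
  cudaMalloc(&d_info, sizeof(h_info));
  cudaMemcpy(d_info, h_info, sizeof(h_info), cudaMemcpyHostToDevice);
  // ... copy input data into d_inp here ...

  const float alpha = 1.0f;  // elu's alpha, passed as the `param` argument
  const int block = 256;
  const int grid = (int)((numel + block - 1) / block);
  uelu_f32<<<grid, block>>>(numel, /*num_dims=*/1, d_info, alpha, d_inp, d_out);
  cudaDeviceSynchronize();

  cudaFree(d_inp);
  cudaFree(d_out);
  cudaFree(d_info);
  return 0;
}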