mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the elu cuda kernel. (#114)
This commit is contained in:
@ -36,12 +36,46 @@ __device__ __forceinline__ T gelu_fwd(T x) {
|
|||||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
|
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ELU forward pass for a single element.
// Returns `x` unchanged when it is strictly positive, and
// `alpha * (exp(x) - 1)` otherwise (this includes x == 0 and any input for
// which the `x > 0` comparison is false).
template<typename T>
__device__ __forceinline__ T elu_fwd(T x, T alpha) {
    const T zero = static_cast<T>(0);
    const T one = static_cast<T>(1);
    // Only the selected operand of the conditional is evaluated, so expg is
    // not called for positive inputs.
    return (x > zero) ? x : alpha * (expg(x) - one);
}
|
||||||
|
|
||||||
// ReLU forward pass for a single element: the larger of `x` and zero,
// as computed by maxg.
template<typename T>
__device__ __forceinline__ T relu_fwd(T x) {
    const T lower_bound = 0.;
    return maxg(x, lower_bound);
}
|
||||||
|
|
||||||
|
// UNARY_OP1(TYPENAME, FN_NAME, FUNC) defines an extern "C" kernel named
// FN_NAME that evaluates the expression FUNC once per element and stores the
// result in `out`. FUNC may refer to `x` (the current input element) and to
// `param` (the extra scalar kernel argument).
//
// Kernel arguments:
//   numel    - number of elements to process.
//   num_dims - number of dimensions; `info` holds the dims first, and the
//              strides start at `info + num_dims`.
//   info     - dims and strides, laid out as described above.
//   param    - scalar of type TYPENAME made available to FUNC.
//   inp      - input buffer; when it is null the element is read from
//              `out[i]` instead, so the operation runs in place on `out`.
//   out      - output buffer, always written at the linear index `i`.
//
// Each thread starts at its global thread index and advances by the total
// number of launched threads (blockDim.x * gridDim.x) until it reaches numel.
//
// The index arithmetic is done in size_t: numel is a size_t, and with a 32-bit
// index the loop could never terminate for numel > UINT_MAX because the index
// would wrap around before reaching numel.
// NOTE(review): get_strided_index is declared elsewhere; if its index
// parameter is narrower than size_t the strided path is still limited to that
// range - confirm against its declaration.
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const TYPENAME param, \
    const TYPENAME *inp, \
    TYPENAME *out \
) { \
    const size_t *dims = info; \
    const size_t *strides = info + num_dims; \
    const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
    const size_t step = static_cast<size_t>(blockDim.x) * gridDim.x; \
    if (is_contiguous(num_dims, dims, strides)) { \
        for (size_t i = first; i < numel; i += step) { \
            TYPENAME x = inp ? inp[i] : out[i]; \
            out[i] = FUNC; \
        } \
    } \
    else { \
        for (size_t i = first; i < numel; i += step) { \
            const size_t strided_i = get_strided_index(i, num_dims, dims, strides); \
            TYPENAME x = inp ? inp[strided_i] : out[i]; \
            out[i] = FUNC; \
        } \
    } \
}
|
||||||
|
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 800
|
#if __CUDA_ARCH__ >= 800
|
||||||
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
|
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
|
||||||
@ -55,6 +89,7 @@ UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
|
|||||||
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
|
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
|
||||||
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
|
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
|
||||||
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
|
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
|
||||||
|
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 530
|
#if __CUDA_ARCH__ >= 530
|
||||||
@ -69,6 +104,7 @@ UNARY_OP(__half, usqr_f16, x*x)
|
|||||||
UNARY_OP(__half, usqrt_f16, sqrtg(x))
|
UNARY_OP(__half, usqrt_f16, sqrtg(x))
|
||||||
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
|
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
|
||||||
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
||||||
|
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
UNARY_OP(uint8_t, ucopy_u8, x)
|
UNARY_OP(uint8_t, ucopy_u8, x)
|
||||||
@ -95,3 +131,5 @@ UNARY_OP(float, ugelu_f32, gelu_fwd(x))
|
|||||||
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
|
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
|
||||||
UNARY_OP(float, urelu_f32, relu_fwd(x))
|
UNARY_OP(float, urelu_f32, relu_fwd(x))
|
||||||
UNARY_OP(double, urelu_f64, relu_fwd(x))
|
UNARY_OP(double, urelu_f64, relu_fwd(x))
|
||||||
|
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
|
||||||
|
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
|
||||||
|
Reference in New Issue
Block a user