Add the relu op.

This commit is contained in:
laurent
2023-06-28 09:38:54 +01:00
parent b805c4114b
commit 1ce3843cab
4 changed files with 47 additions and 9 deletions

View File

@ -26,13 +26,19 @@ extern "C" __global__ void FN_NAME( \
} \
template<typename T>
__device__ T gelu_fwd(T x) {
__device__ __forceinline__ T gelu_fwd(T x) {
T x_sq = x * x;
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
}
template<typename T>
__device__ __forceinline__ T relu_fwd(T x) {
T zero = 0.;
return maxg(x, zero);
}
#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
@ -44,7 +50,8 @@ UNARY_OP(__half, ucos_f16, cosg(x))
UNARY_OP(__half, uabs_f16, absg(x))
UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x))
UNARY_OP(__half, gelu_f16, gelu_fwd(x))
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x))
#endif
UNARY_OP(float, ucopy_f32, x)
@ -65,5 +72,7 @@ UNARY_OP(float, usqr_f32, x*x)
UNARY_OP(double, usqr_f64, x*x)
UNARY_OP(float, usqrt_f32, sqrtg(x))
UNARY_OP(double, usqrt_f64, sqrtg(x))
UNARY_OP(float, gelu_f32, gelu_fwd(x))
UNARY_OP(double, gelu_f64, gelu_fwd(x))
UNARY_OP(float, ugelu_f32, gelu_fwd(x))
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x))