Fix the gelu kernel for f16.

This commit is contained in:
laurent
2023-06-23 13:38:54 +01:00
parent db5526d51a
commit f8848db001
2 changed files with 26 additions and 42 deletions

View File

@ -19,11 +19,10 @@ extern "C" __global__ void FN_NAME( \
template<typename T>
// Forward GELU using the tanh approximation:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// sqrt(2/pi) is expressed as M_2_SQRTPI * M_SQRT1_2 = (2/sqrt(pi)) * (1/sqrt(2)).
// All constants go through static_cast<T> so the arithmetic stays in T end to
// end; this is what makes the kernel correct for T = __half (f16), where bare
// double literals would force promotions that do not exist for half types.
// NOTE: `constexpr T coeff = 0.044715;` is deliberately avoided — it fails to
// compile when T is a non-literal type such as __half.
// `tanhg` is the project's generic tanh helper, declared elsewhere in this file.
__device__ T gelu_fwd(T x) {
    T x_sq = x * x;
    T x_cube = x_sq * x;
    // Inner polynomial: x + 0.044715 * x^3
    T alpha = x + static_cast<T>(0.044715) * x_cube;
    // M_2_SQRTPI * M_SQRT1_2 is folded in double precision at compile time,
    // then cast once to T before the multiply.
    return static_cast<T>(0.5) * x
        * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
}