Fix the gelu kernel for f16.
@@ -19,11 +19,10 @@ extern "C" __global__ void FN_NAME( \
 template<typename T>
 __device__ T gelu_fwd(T x) {
-    constexpr T fastCoeff = 0.044715;
     T x_sq = x * x;
     T x_cube = x_sq * x;
-    T alpha = x + fastCoeff * x_cube;
-    return 0.5 * x * (1.0 + tanhg(M_2_SQRTPI * M_SQRT1_2 * alpha));
+    T alpha = x + static_cast<T>(0.044715) * x_cube;
+    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
 }
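For context, the old kernel body most likely failed to compile when instantiated with T = __half: `constexpr T fastCoeff = 0.044715;` needs a constexpr constructor, which CUDA's `__half` type does not provide, and bare double literals such as `0.5` and `1.0` promote the mixed expressions out of `T`. The fix casts every constant to `T` so the whole computation stays in the kernel's element type. Below is a minimal standalone sketch of the fixed kernel in use; the `tanhg` overloads, the `gelu_demo` kernel, and the host driver are assumptions made here for illustration and are not part of candle's sources.

    // Minimal standalone sketch (hypothetical file gelu_check.cu, not from candle).
    // Build with: nvcc -arch=sm_70 gelu_check.cu -o gelu_check
    #include <cuda_fp16.h>
    #include <math.h>
    #include <stdio.h>

    // Generic tanh helpers standing in for candle's tanhg; these overloads are
    // an assumption here. The __half version routes through float, since there
    // is no native half-precision tanh intrinsic.
    __device__ __forceinline__ float tanhg(float x) { return tanhf(x); }
    __device__ __forceinline__ __half tanhg(__half x) {
        return __float2half(tanhf(__half2float(x)));
    }

    // The fixed kernel body from the diff: every constant is cast to T, so the
    // expression never mixes double literals with __half operands.
    template<typename T>
    __device__ T gelu_fwd(T x) {
        T x_sq = x * x;
        T x_cube = x_sq * x;
        T alpha = x + static_cast<T>(0.044715) * x_cube;
        return static_cast<T>(0.5) * x
            * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
    }

    // Hypothetical demo kernel: evaluate gelu_fwd at one point in f32 and f16.
    __global__ void gelu_demo(float* out_f32, __half* out_f16, float x) {
        *out_f32 = gelu_fwd<float>(x);
        *out_f16 = gelu_fwd<__half>(__float2half(x));
    }

    int main() {
        float* d_f32;
        __half* d_f16;
        cudaMalloc(&d_f32, sizeof(float));
        cudaMalloc(&d_f16, sizeof(__half));
        gelu_demo<<<1, 1>>>(d_f32, d_f16, 1.0f);
        float h_f32;
        __half h_f16;
        cudaMemcpy(&h_f32, d_f32, sizeof(float), cudaMemcpyDeviceToHost);
        cudaMemcpy(&h_f16, d_f16, sizeof(__half), cudaMemcpyDeviceToHost);
        // The tanh-approximation gelu(1.0) is about 0.8412; the f16 result is
        // the same value at coarser precision.
        printf("f32: %f  f16: %f\n", h_f32, __half2float(h_f16));
        cudaFree(d_f32);
        cudaFree(d_f16);
        return 0;
    }

Note that `M_2_SQRTPI * M_SQRT1_2` equals sqrt(2/pi), the usual coefficient in the tanh approximation of GELU. Casting that product once, rather than casting each factor, lets the multiplication happen in double before a single narrowing conversion to `T`, which loses less precision for f16.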