mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add an erf-based gelu op (#900)
* Erf-based gelu. * Add the erf-backed gelu. * Test the new gelu op (which is not gelu_new).
This commit is contained in:
@ -129,6 +129,10 @@ __device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
|
||||
// powg, double overload: forwards both arguments to pow and returns its result.
__device__ __forceinline__ double powg(double a, double b) {
  const double raised = pow(a, b);
  return raised;
}
|
||||
// tanhg, float overload: forwards to the single-precision tanhf.
__device__ __forceinline__ float tanhg(float a) {
  const float result = tanhf(a);
  return result;
}
|
||||
// tanhg, double overload: forwards to the double-precision tanh.
__device__ __forceinline__ double tanhg(double a) {
  const double result = tanh(a);
  return result;
}
|
||||
// erfg, float overload: forwards to the single-precision erff.
__device__ __forceinline__ float erfg(float a) {
  const float result = erff(a);
  return result;
}
|
||||
// erfg, double overload: forwards to the double-precision erf.
__device__ __forceinline__ double erfg(double a) {
  const double result = erf(a);
  return result;
}
|
||||
// normcdfg, float overload: forwards to the single-precision normcdff.
__device__ __forceinline__ float normcdfg(float a) {
  const float cdf = normcdff(a);
  return cdf;
}
|
||||
// normcdfg, double overload: forwards to the double-precision normcdf.
__device__ __forceinline__ double normcdfg(double a) {
  const double cdf = normcdf(a);
  return cdf;
}
|
||||
// maxg, float overload: forwards both arguments to fmaxf.
__device__ __forceinline__ float maxg(float a, float b) {
  const float larger = fmaxf(a, b);
  return larger;
}
|
||||
// maxg, double overload: forwards both arguments to fmax.
__device__ __forceinline__ double maxg(double a, double b) {
  const double larger = fmax(a, b);
  return larger;
}
|
||||
// ming, float overload: forwards both arguments to fminf.
__device__ __forceinline__ float ming(float a, float b) {
  const float smaller = fminf(a, b);
  return smaller;
}
|
||||
@ -157,6 +161,8 @@ __device__ __forceinline__ __half sing(__half a) { return hsin(a); }
|
||||
// recipg, __half overload: returns 1 / a using __half division.
// The constant one is built with __float2half(1.0f) instead of the bare
// literal `1.0`, so no double-precision literal (and no implicit
// double -> __half conversion) appears in this device function.
__device__ __forceinline__ __half recipg(__half a) {
  const __half one = __float2half(1.0f);
  return one / a;
}
|
||||
// maxg, __half overload: forwards both arguments to the __hmax_nan intrinsic.
__device__ __forceinline__ __half maxg(__half a, __half b) {
  const __half larger = __hmax_nan(a, b);
  return larger;
}
|
||||
// tanhg, __half overload: widen to float, evaluate tanhf, narrow back to __half.
__device__ __forceinline__ __half tanhg(__half a) {
  const float wide = __half2float(a);
  return __float2half(tanhf(wide));
}
|
||||
// erfg, __half overload: widen to float, evaluate erff, narrow back to __half.
__device__ __forceinline__ __half erfg(__half a) {
  const float wide = __half2float(a);
  return __float2half(erff(wide));
}
|
||||
// normcdfg, __half overload: widen to float, evaluate normcdff, narrow back to __half.
__device__ __forceinline__ __half normcdfg(__half a) {
  const float wide = __half2float(a);
  return __float2half(normcdff(wide));
}
|
||||
// ming, __half overload: forwards both arguments to the __hmin_nan intrinsic.
__device__ __forceinline__ __half ming(__half a, __half b) {
  const __half smaller = __hmin_nan(a, b);
  return smaller;
}
|
||||
// logg, __half overload: forwards to the half-precision hlog.
__device__ __forceinline__ __half logg(__half a) {
  const __half result = hlog(a);
  return result;
}
|
||||
// expg, __half overload: forwards to the half-precision hexp.
__device__ __forceinline__ __half expg(__half a) {
  const __half result = hexp(a);
  return result;
}
|
||||
@ -173,6 +179,8 @@ __device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a);
|
||||
// recipg, __nv_bfloat16 overload: returns 1 / a using __nv_bfloat16 division.
// The constant one is built with __float2bfloat16(1.0f) instead of the bare
// literal `1.0`, so no double-precision literal (and no implicit
// double -> __nv_bfloat16 conversion) appears in this device function.
__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) {
  const __nv_bfloat16 one = __float2bfloat16(1.0f);
  return one / a;
}
|
||||
// maxg, __nv_bfloat16 overload: forwards both arguments to the __hmax_nan intrinsic.
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) {
  const __nv_bfloat16 larger = __hmax_nan(a, b);
  return larger;
}
|
||||
// tanhg, __nv_bfloat16 overload: widen to float, evaluate tanhf, narrow back to bf16.
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) {
  const float wide = __bfloat162float(a);
  return __float2bfloat16(tanhf(wide));
}
|
||||
// erfg, __nv_bfloat16 overload: widen to float, evaluate erff, narrow back to bf16.
__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) {
  const float wide = __bfloat162float(a);
  return __float2bfloat16(erff(wide));
}
|
||||
// normcdfg, __nv_bfloat16 overload: widen to float, evaluate normcdff, narrow back to bf16.
__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) {
  const float wide = __bfloat162float(a);
  return __float2bfloat16(normcdff(wide));
}
|
||||
// ming, __nv_bfloat16 overload: forwards both arguments to the __hmin_nan intrinsic.
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) {
  const __nv_bfloat16 smaller = __hmin_nan(a, b);
  return smaller;
}
|
||||
// logg, __nv_bfloat16 overload: forwards to hlog.
__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) {
  const __nv_bfloat16 result = hlog(a);
  return result;
}
|
||||
// expg, __nv_bfloat16 overload: forwards to hexp.
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) {
  const __nv_bfloat16 result = hexp(a);
  return result;
}
|
||||
|
@ -28,6 +28,11 @@ extern "C" __global__ void FN_NAME( \
|
||||
} \
|
||||
} \
|
||||
|
||||
// GELU forward pass in its normal-CDF form: the input multiplied by
// normcdfg(x). T must have a `normcdfg` overload and `operator*` in scope.
template <typename T>
__device__ __forceinline__ T gelu_erf_fwd(T x) {
  const T cdf = normcdfg(x);
  return x * cdf;
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_fwd(T x) {
|
||||
T x_sq = x * x;
|
||||
@ -86,10 +91,13 @@ UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
|
||||
// __nv_bfloat16 unary-op instantiations.
// NOTE(review): UNARY_OP / UNARY_OP1 are macros defined outside this view;
// presumably each invocation emits one kernel named by the second argument
// that evaluates the third-argument expression on an element `x`
// (UNARY_OP1 additionally exposes a scalar `param`) — confirm against the
// macro definitions.
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
|
||||
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
|
||||
UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
|
||||
// erf and normal-CDF ops, routed through the erfg / normcdfg overloads.
UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x))
|
||||
UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
|
||||
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
|
||||
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
|
||||
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
|
||||
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
|
||||
// GELU variant computed by gelu_erf_fwd, distinct from gelu_fwd above.
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
|
||||
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
|
||||
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
||||
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
||||
@ -104,10 +112,13 @@ UNARY_OP(__half, ulog_f16, logg(x))
|
||||
// __half unary-op instantiations.
// NOTE(review): UNARY_OP / UNARY_OP1 are macros defined outside this view;
// presumably each invocation emits one kernel named by the second argument
// that evaluates the third-argument expression on an element `x`
// (UNARY_OP1 additionally exposes a scalar `param`) — confirm against the
// macro definitions.
UNARY_OP(__half, usin_f16, sing(x))
|
||||
UNARY_OP(__half, ucos_f16, cosg(x))
|
||||
UNARY_OP(__half, utanh_f16, tanhg(x))
|
||||
// erf and normal-CDF ops, routed through the erfg / normcdfg overloads.
UNARY_OP(__half, uerf_f16, erfg(x))
|
||||
UNARY_OP(__half, unormcdf_f16, normcdfg(x))
|
||||
UNARY_OP(__half, uabs_f16, absg(x))
|
||||
UNARY_OP(__half, usqr_f16, x*x)
|
||||
UNARY_OP(__half, usqrt_f16, sqrtg(x))
|
||||
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
|
||||
// GELU variant computed by gelu_erf_fwd, distinct from gelu_fwd above.
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
|
||||
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
||||
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
||||
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
||||
@ -131,6 +142,10 @@ UNARY_OP(float, ucos_f32, cosg(x))
|
||||
// float / double unary-op instantiations (one line per element type).
// NOTE(review): UNARY_OP is a macro defined outside this view; presumably
// each invocation emits one kernel named by the second argument that
// evaluates the third-argument expression on an element `x` — confirm
// against the macro definition.
UNARY_OP(double, ucos_f64, cosg(x))
|
||||
UNARY_OP(float, utanh_f32, tanhg(x))
|
||||
UNARY_OP(double, utanh_f64, tanhg(x))
|
||||
// erf and normal-CDF ops, routed through the erfg / normcdfg overloads.
UNARY_OP(float, uerf_f32, erfg(x))
|
||||
UNARY_OP(double, uerf_f64, erfg(x))
|
||||
UNARY_OP(float, unormcdf_f32, normcdfg(x))
|
||||
UNARY_OP(double, unormcdf_f64, normcdfg(x))
|
||||
UNARY_OP(float, uabs_f32, absg(x))
|
||||
UNARY_OP(double, uabs_f64, absg(x))
|
||||
UNARY_OP(float, usqr_f32, x*x)
|
||||
@ -139,6 +154,8 @@ UNARY_OP(float, usqrt_f32, sqrtg(x))
|
||||
// float / double instantiations for sqrt and the activation functions.
// NOTE(review): UNARY_OP / UNARY_OP1 are macros defined outside this view;
// presumably each invocation emits one kernel named by the second argument
// that evaluates the third-argument expression on an element `x`
// (UNARY_OP1 additionally exposes a scalar `param`) — confirm against the
// macro definitions.
UNARY_OP(double, usqrt_f64, sqrtg(x))
|
||||
UNARY_OP(float, ugelu_f32, gelu_fwd(x))
|
||||
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
|
||||
// GELU variant computed by gelu_erf_fwd, distinct from gelu_fwd above.
UNARY_OP(float, ugelu_erf_f32, gelu_erf_fwd(x))
|
||||
UNARY_OP(double, ugelu_erf_f64, gelu_erf_fwd(x))
|
||||
UNARY_OP(float, urelu_f32, relu_fwd(x))
|
||||
UNARY_OP(double, urelu_f64, relu_fwd(x))
|
||||
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
|
||||
|
Reference in New Issue
Block a user