Address comments.

This commit is contained in:
Nicolas Patry
2023-06-23 13:05:57 +02:00
parent fd21c708ab
commit 56ae71dd4c
2 changed files with 4 additions and 3 deletions

View File

@ -32,7 +32,7 @@ UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x) UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, usqr_f16, x*x) UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x)) UNARY_OP(__half, usqrt_f16, sqrtg(x))
// UNARY_OP(__half, gelu_f16, gelu_fwd(x)) UNARY_OP(__half, gelu_f16, gelu_fwd(x))
#endif #endif
UNARY_OP(float, ucopy_f32, x) UNARY_OP(float, ucopy_f32, x)
@ -44,3 +44,4 @@ UNARY_OP(float, usqr_f64, x*x)
UNARY_OP(float, usqrt_f32, sqrtg(x)) UNARY_OP(float, usqrt_f32, sqrtg(x))
UNARY_OP(float, usqrt_f64, sqrtg(x)) UNARY_OP(float, usqrt_f64, sqrtg(x))
UNARY_OP(float, gelu_f32, gelu_fwd(x)) UNARY_OP(float, gelu_f32, gelu_fwd(x))
UNARY_OP(float, gelu_f64, gelu_fwd(x))

View File

@ -143,14 +143,14 @@ impl UnaryOp for Sqrt {
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline] #[inline]
pub fn gelu_f32(v: f32) -> f32 { pub fn gelu_f32(v: f32) -> f32 {
0.5 * (v) 0.5 * v
* (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) * (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
} }
/// `gelu` operation /// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline] #[inline]
pub fn gelu_f64(v: f64) -> f64 { pub fn gelu_f64(v: f64) -> f64 {
0.5 * (v) 0.5 * v
* (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) * (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
} }
impl UnaryOp for Gelu { impl UnaryOp for Gelu {