From 56ae71dd4c4228321893646b85622478f65e1fb1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 23 Jun 2023 13:05:57 +0200 Subject: [PATCH] Address comments. --- kernels/src/unary.cu | 3 ++- src/op.rs | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/kernels/src/unary.cu b/kernels/src/unary.cu index 3f05b0bc..faf73bbf 100644 --- a/kernels/src/unary.cu +++ b/kernels/src/unary.cu @@ -32,7 +32,7 @@ UNARY_OP(__half, ucopy_f16, x) UNARY_OP(__half, uneg_f16, -x) UNARY_OP(__half, usqr_f16, x*x) UNARY_OP(__half, usqrt_f16, sqrtg(x)) -// UNARY_OP(__half, gelu_f16, gelu_fwd(x)) +UNARY_OP(__half, gelu_f16, gelu_fwd(x)) #endif UNARY_OP(float, ucopy_f32, x) @@ -44,3 +44,4 @@ UNARY_OP(float, usqr_f64, x*x) UNARY_OP(float, usqrt_f32, sqrtg(x)) UNARY_OP(float, usqrt_f64, sqrtg(x)) UNARY_OP(float, gelu_f32, gelu_fwd(x)) +UNARY_OP(float, gelu_f64, gelu_fwd(x)) diff --git a/src/op.rs b/src/op.rs index f2880580..4fae5458 100644 --- a/src/op.rs +++ b/src/op.rs @@ -143,14 +143,14 @@ impl UnaryOp for Sqrt { /// #[inline] pub fn gelu_f32(v: f32) -> f32 { - 0.5 * (v) + 0.5 * v * (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) } /// `gelu` operation /// #[inline] pub fn gelu_f64(v: f64) -> f64 { - 0.5 * (v) + 0.5 * v * (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) } impl UnaryOp for Gelu {