From fd21c708ab05fbe0396add46dafbb862e1b81f84 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 22 Jun 2023 21:56:46 +0200 Subject: [PATCH 1/3] Creating Gelu op (no backward). --- kernels/src/unary.cu | 12 ++++++++++++ src/op.rs | 28 ++++++++++++++++++++++++++++ src/tensor.rs | 3 +++ 3 files changed, 43 insertions(+) diff --git a/kernels/src/unary.cu b/kernels/src/unary.cu index 5bee3725..3f05b0bc 100644 --- a/kernels/src/unary.cu +++ b/kernels/src/unary.cu @@ -17,11 +17,22 @@ extern "C" __global__ void FN_NAME( \ } \ } \ +template +__device__ T gelu_fwd(T x) { + constexpr T fastCoeff = 0.044715; + T x_sq = x * x; + T x_cube = x_sq * x; + T alpha = x + fastCoeff * x_cube; + return 0.5 * x * (1.0 + tanhg(M_2_SQRTPI * M_SQRT1_2 * alpha)); +} + + #if __CUDA_ARCH__ >= 530 UNARY_OP(__half, ucopy_f16, x) UNARY_OP(__half, uneg_f16, -x) UNARY_OP(__half, usqr_f16, x*x) UNARY_OP(__half, usqrt_f16, sqrtg(x)) +// UNARY_OP(__half, gelu_f16, gelu_fwd(x)) #endif UNARY_OP(float, ucopy_f32, x) @@ -32,3 +43,4 @@ UNARY_OP(float, usqr_f32, x*x) UNARY_OP(float, usqr_f64, x*x) UNARY_OP(float, usqrt_f32, sqrtg(x)) UNARY_OP(float, usqrt_f64, sqrtg(x)) +UNARY_OP(float, gelu_f32, gelu_fwd(x)) diff --git a/src/op.rs b/src/op.rs index 45fe97a4..f2880580 100644 --- a/src/op.rs +++ b/src/op.rs @@ -22,6 +22,7 @@ pub(crate) enum Op { Sqrt(Tensor), ToDevice(Tensor), Transpose(Tensor, usize, usize), + Gelu(Tensor), // TODO: Support for custom ops. 
} @@ -52,6 +53,7 @@ pub(crate) struct Sub; pub(crate) struct Neg; pub(crate) struct Sqr; pub(crate) struct Sqrt; +pub(crate) struct Gelu; impl BinaryOp for Add { const NAME: &'static str = "add"; @@ -136,3 +138,29 @@ impl UnaryOp for Sqrt { const KERNEL_F32: &'static str = "usqrt_f32"; const KERNEL_F64: &'static str = "usqrt_f64"; } + +/// `gelu` operation +/// +#[inline] +pub fn gelu_f32(v: f32) -> f32 { + 0.5 * (v) + * (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) +} +/// `gelu` operation +/// +#[inline] +pub fn gelu_f64(v: f64) -> f64 { + 0.5 * (v) + * (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) +} +impl UnaryOp for Gelu { + const NAME: &'static str = "gelu"; + fn f32(v1: f32) -> f32 { + gelu_f32(v1) + } + fn f64(v1: f64) -> f64 { + gelu_f64(v1) + } + const KERNEL_F32: &'static str = "gelu_f32"; + const KERNEL_F64: &'static str = "gelu_f64"; +} diff --git a/src/tensor.rs b/src/tensor.rs index 1e411dcc..b69319d3 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -240,6 +240,7 @@ impl Tensor { unary_op!(neg, Neg); unary_op!(sqr, Sqr); unary_op!(sqrt, Sqrt); + unary_op!(gelu, Gelu); pub fn to_scalar(&self) -> Result { if self.rank() != 0 { return Err(Error::UnexpectedNumberOfDims { @@ -766,6 +767,7 @@ impl Tensor { | Op::Transpose(node, _, _) | Op::Sqr(node) | Op::Sqrt(node) + | Op::Gelu(node) | Op::Neg(node) => { let (tg, nodes) = walk(node, nodes, already_seen); track_grad |= tg; @@ -854,6 +856,7 @@ impl Tensor { *sum_grad = sum_grad.add(&arg_grad)? } Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }), + Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }), Op::Sqr(arg) => { let arg_grad = arg.mul(&grad)?.affine(2., 0.)?; let sum_grad = grads.or_insert(arg)?; From 56ae71dd4c4228321893646b85622478f65e1fb1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 23 Jun 2023 13:05:57 +0200 Subject: [PATCH 2/3] Address comments. 
--- kernels/src/unary.cu | 3 ++- src/op.rs | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/kernels/src/unary.cu b/kernels/src/unary.cu index 3f05b0bc..faf73bbf 100644 --- a/kernels/src/unary.cu +++ b/kernels/src/unary.cu @@ -32,7 +32,7 @@ UNARY_OP(__half, ucopy_f16, x) UNARY_OP(__half, uneg_f16, -x) UNARY_OP(__half, usqr_f16, x*x) UNARY_OP(__half, usqrt_f16, sqrtg(x)) -// UNARY_OP(__half, gelu_f16, gelu_fwd(x)) +UNARY_OP(__half, gelu_f16, gelu_fwd(x)) #endif UNARY_OP(float, ucopy_f32, x) @@ -44,3 +44,4 @@ UNARY_OP(float, usqr_f64, x*x) UNARY_OP(float, usqrt_f32, sqrtg(x)) UNARY_OP(float, usqrt_f64, sqrtg(x)) UNARY_OP(float, gelu_f32, gelu_fwd(x)) +UNARY_OP(float, gelu_f64, gelu_fwd(x)) diff --git a/src/op.rs b/src/op.rs index f2880580..4fae5458 100644 --- a/src/op.rs +++ b/src/op.rs @@ -143,14 +143,14 @@ impl UnaryOp for Sqrt { /// #[inline] pub fn gelu_f32(v: f32) -> f32 { - 0.5 * (v) + 0.5 * v * (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) } /// `gelu` operation /// #[inline] pub fn gelu_f64(v: f64) -> f64 { - 0.5 * (v) + 0.5 * v * (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) } impl UnaryOp for Gelu { From 09b7731b8d2e9cc3700b097792be7f1daabc9095 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 23 Jun 2023 13:10:26 +0200 Subject: [PATCH 3/3] Fix unary op. 
--- kernels/src/unary.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/kernels/src/unary.cu b/kernels/src/unary.cu index faf73bbf..3fe830e1 100644 --- a/kernels/src/unary.cu +++ b/kernels/src/unary.cu @@ -36,12 +36,12 @@ UNARY_OP(__half, gelu_f16, gelu_fwd(x)) #endif UNARY_OP(float, ucopy_f32, x) -UNARY_OP(float, ucopy_f64, x) +UNARY_OP(double, ucopy_f64, x) UNARY_OP(float, uneg_f32, -x) -UNARY_OP(float, uneg_f64, -x) +UNARY_OP(double, uneg_f64, -x) UNARY_OP(float, usqr_f32, x*x) -UNARY_OP(float, usqr_f64, x*x) +UNARY_OP(double, usqr_f64, x*x) UNARY_OP(float, usqrt_f32, sqrtg(x)) -UNARY_OP(float, usqrt_f64, sqrtg(x)) +UNARY_OP(double, usqrt_f64, sqrtg(x)) UNARY_OP(float, gelu_f32, gelu_fwd(x)) -UNARY_OP(float, gelu_f64, gelu_fwd(x)) +UNARY_OP(double, gelu_f64, gelu_fwd(x))