From f8848db001e54d8da55338d87b898b20a305a959 Mon Sep 17 00:00:00 2001 From: laurent Date: Fri, 23 Jun 2023 13:38:54 +0100 Subject: [PATCH] Fix the gelu kernel for f16. --- kernels/src/affine.cu | 63 +++++++++++++++++-------------------------- kernels/src/unary.cu | 5 ++-- 2 files changed, 26 insertions(+), 42 deletions(-) diff --git a/kernels/src/affine.cu b/kernels/src/affine.cu index af25aba1..63835940 100644 --- a/kernels/src/affine.cu +++ b/kernels/src/affine.cu @@ -1,42 +1,27 @@ #include "cuda_utils.cuh" +#include -extern "C" __global__ void affine_f32( - const size_t numel, - const size_t num_dims, - const size_t *info, - const float *x, - float *y, - const float mul, - const float add -) { - const size_t *dims = info; - const size_t *strides = info + num_dims; - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; - } - // This is likely to be very very slow, we should either optimize the contiguous case - // as a separate kernel, proceed by block, improve the stride computations (and probably - // do all of these). - unsigned strided_i = get_strided_index(i, num_dims, dims, strides); - y[strided_i] = x[i] * mul + add; -} +#define AFFINE_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const TYPENAME *x, \ + TYPENAME *y, \ + const TYPENAME mul, \ + const TYPENAME add \ +) { \ + const size_t *dims = info; \ + const size_t *strides = info + num_dims; \ + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \ + if (i >= numel) { \ + return; \ + } \ + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ + y[strided_i] = x[i] * mul + add; \ +} \ + +AFFINE_OP(float, affine_f32) +AFFINE_OP(double, affine_f64) +AFFINE_OP(uint32_t, affine_u32) -extern "C" __global__ void affine_f64( - const size_t numel, - const size_t num_dims, - const size_t *info, - const double *x, - double *y, - const double mul, - const double add -) { - const size_t *dims = info; - const size_t *strides = info + num_dims; - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; - } - unsigned strided_i = get_strided_index(i, num_dims, dims, strides); - y[strided_i] = x[i] * mul + add; -} diff --git a/kernels/src/unary.cu b/kernels/src/unary.cu index 3fe830e1..408f92a8 100644 --- a/kernels/src/unary.cu +++ b/kernels/src/unary.cu @@ -19,11 +19,10 @@ extern "C" __global__ void FN_NAME( \ template __device__ T gelu_fwd(T x) { - constexpr T fastCoeff = 0.044715; T x_sq = x * x; T x_cube = x_sq * x; - T alpha = x + fastCoeff * x_cube; - return 0.5 * x * (1.0 + tanhg(M_2_SQRTPI * M_SQRT1_2 * alpha)); + T alpha = x + static_cast(0.044715) * x_cube; + return static_cast(0.5) * x * (static_cast(1.0) + tanhg(static_cast(M_2_SQRTPI * M_SQRT1_2) * alpha)); }