diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index a9847256..09e8dd73 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -36,12 +36,69 @@ __device__ __forceinline__ T gelu_fwd(T x) {
   return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
 }
 
+// ELU forward: returns `x` unchanged when `x > 0`; otherwise returns
+// `alpha * (expg(x) - 1)`.
+// NOTE(review): `expg` is defined outside this hunk - presumably the per-type
+// exponential helper; confirm.
+template<typename T>
+__device__ __forceinline__ T elu_fwd(T x, T alpha) {
+  if (x > static_cast<T>(0)) {
+    return x;
+  }
+  return alpha * (expg(x) - static_cast<T>(1));
+}
+
 template<typename T>
 __device__ __forceinline__ T relu_fwd(T x) {
     T zero = 0.;
     return maxg(x, zero);
 }
 
+// Defines an `extern "C"` kernel `FN_NAME` that applies `FUNC` element-wise.
+// `FUNC` is an expression that may use `x` (the element just loaded) and
+// `param` (a scalar kernel argument); its value is stored to `out[i]`.
+//
+// Kernel arguments:
+//   numel    - number of elements to process.
+//   num_dims - number of dimensions described by `info`.
+//   info     - `num_dims` dims immediately followed by the strides.
+//   param    - scalar of type `TYPENAME` made available to `FUNC`.
+//   inp      - source buffer; when null, `x` is read from `out[i]` instead.
+//   out      - destination buffer.
+//
+// Each thread starts at `blockIdx.x * blockDim.x + threadIdx.x` and advances
+// by `blockDim.x * gridDim.x`. Both are computed as `size_t` so the loop
+// cannot wrap around when `numel` exceeds the range of `unsigned int`.
+// NOTE(review): `is_contiguous` and `get_strided_index` are defined outside
+// this hunk; confirm `get_strided_index` accepts and returns 64-bit indices.
+#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME param, \
+    const TYPENAME *inp, \
+    TYPENAME *out \
+) { \
+    const size_t *dims = info; \
+    const size_t *strides = info + num_dims; \
+    const size_t first_i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
+    const size_t step_i = static_cast<size_t>(blockDim.x) * gridDim.x; \
+    if (is_contiguous(num_dims, dims, strides)) { \
+        for (size_t i = first_i; i < numel; i += step_i) { \
+            TYPENAME x = inp ? inp[i] : out[i]; \
+            out[i] = FUNC; \
+        } \
+    } \
+    else { \
+        for (size_t i = first_i; i < numel; i += step_i) { \
+            const size_t strided_i = get_strided_index(i, num_dims, dims, strides); \
+            TYPENAME x = inp ? inp[strided_i] : out[i]; \
+            out[i] = FUNC; \
+        } \
+    } \
+}
+
 
 #if __CUDA_ARCH__ >= 800
 UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
@@ -55,6 +112,7 @@ UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
 UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
 UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
 UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
+UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
 #endif
 
 #if __CUDA_ARCH__ >= 530
@@ -69,6 +127,7 @@ UNARY_OP(__half, usqr_f16, x*x)
 UNARY_OP(__half, usqrt_f16, sqrtg(x))
 UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
 UNARY_OP(__half, urelu_f16, relu_fwd(x))
+UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
 #endif
 
 UNARY_OP(uint8_t, ucopy_u8, x)
@@ -95,3 +154,5 @@ UNARY_OP(float, ugelu_f32, gelu_fwd(x))
 UNARY_OP(double, ugelu_f64, gelu_fwd(x))
 UNARY_OP(float, urelu_f32, relu_fwd(x))
 UNARY_OP(double, urelu_f64, relu_fwd(x))
+UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
+UNARY_OP1(double, uelu_f64, elu_fwd(x, param))