diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index 2256c6bb..74ba1fac 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -57,7 +57,7 @@ __device__ __forceinline__ T relu_fwd(T x) { template __device__ __forceinline__ T silu_fwd(T x) { - return x / (static_cast(1) + expg(-x)); + return x / (static_cast(1) + expg(-x)); } #define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \