Fix the silu cuda kernel. (#1710)

This commit is contained in:
Laurent Mazare
2024-02-14 11:08:18 +01:00
committed by GitHub
parent 2d5f2a728d
commit 121a71e01f

View File

@@ -57,7 +57,7 @@ __device__ __forceinline__ T relu_fwd(T x) {
 template<typename T>
 __device__ __forceinline__ T silu_fwd(T x) {
-    return x / (static_cast<scalar_t>(1) + expg(-x));
+    return x / (static_cast<T>(1) + expg(-x));
 }
 #define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \