mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Fix the silu cuda kernel. (#1710)
This commit is contained in:
@ -57,7 +57,7 @@ __device__ __forceinline__ T relu_fwd(T x) {
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T silu_fwd(T x) {
|
||||
return x / (static_cast<scalar_t>(1) + expg(-x));
|
||||
return x / (static_cast<T>(1) + expg(-x));
|
||||
}
|
||||
|
||||
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
||||
|
Reference in New Issue
Block a user