mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Fix the silu cuda kernel. (#1710)
This commit is contained in:
@ -57,7 +57,7 @@ __device__ __forceinline__ T relu_fwd(T x) {
|
|||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__device__ __forceinline__ T silu_fwd(T x) {
|
__device__ __forceinline__ T silu_fwd(T x) {
|
||||||
return x / (static_cast<scalar_t>(1) + expg(-x));
|
return x / (static_cast<T>(1) + expg(-x));
|
||||||
}
|
}
|
||||||
|
|
||||||
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
||||||
|
Reference in New Issue
Block a user