mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op * small fix * add as a method on `Tensor` * implement gradient calculation for sigmoid * add sigmoid tests * we should have a specialized op for this * fix clippy * fix clippy 2 * Revert all previous commits in favor of a `CustomOp` based solution * use `CustomOp1` implementation * fix rustfmt * experimental add metal impl * add cuda kernel impl * fix fmt * Add a test + reduce some cuda duplication. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -60,6 +60,11 @@ __device__ __forceinline__ T silu_fwd(T x) {
|
||||
return x / (static_cast<T>(1) + expg(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T sigmoid_fwd(T x) {
|
||||
return recipg(static_cast<T>(1) + expg(-x));
|
||||
}
|
||||
|
||||
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
@ -116,6 +121,7 @@ UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
||||
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
|
||||
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
||||
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
|
||||
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -142,6 +148,7 @@ UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
||||
UNARY_OP(__half, usilu_f16, silu_fwd(x))
|
||||
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
||||
UNARY_OP(__half, usign_f16, sign_(x))
|
||||
UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x))
|
||||
#endif
|
||||
|
||||
UNARY_OP(uint8_t, ucopy_u8, x)
|
||||
@ -193,3 +200,5 @@ UNARY_OP1(float, upowf_f32, powg(x, param))
|
||||
UNARY_OP1(double, upowf_f64, powg(x, param))
|
||||
UNARY_OP(float, usign_f32, sign_(x))
|
||||
UNARY_OP(double, usign_f64, sign_(x))
|
||||
UNARY_OP(float, usigmoid_f32, sigmoid_fwd(x))
|
||||
UNARY_OP(double, usigmoid_f64, sigmoid_fwd(x))
|
||||
|
Reference in New Issue
Block a user