mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add support for "sign" on tensors (#2012)
* add the sign unary operator * remove uneeded import * remove uneeded import * undo formatting * undo formatting * remove unnecessary redefintion * allow gradient to flow through for sign and round * fix cpu ops to ensure that negzero and positive zero are handled properly * clippy fixes * Properly avoid gradient tracking. * Use a branchless version. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -86,6 +86,11 @@ extern "C" __global__ void FN_NAME( \
|
||||
} \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
__device__ T sign_(T t) {
|
||||
return static_cast<T>(t > static_cast<T>(0)) - static_cast<T>(t < static_cast<T>(0));
|
||||
}
|
||||
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
|
||||
@ -110,6 +115,7 @@ UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
|
||||
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
||||
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
|
||||
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
||||
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -135,6 +141,7 @@ UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
||||
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
||||
UNARY_OP(__half, usilu_f16, silu_fwd(x))
|
||||
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
||||
UNARY_OP(__half, usign_f16, sign_(x))
|
||||
#endif
|
||||
|
||||
UNARY_OP(uint8_t, ucopy_u8, x)
|
||||
@ -184,3 +191,5 @@ UNARY_OP(float, usilu_f32, silu_fwd(x))
|
||||
UNARY_OP(double, usilu_f64, silu_fwd(x))
|
||||
UNARY_OP1(float, upowf_f32, powg(x, param))
|
||||
UNARY_OP1(double, upowf_f64, powg(x, param))
|
||||
UNARY_OP(float, usign_f32, sign_(x))
|
||||
UNARY_OP(double, usign_f64, sign_(x))
|
||||
|
Reference in New Issue
Block a user