Add support for "sign" on tensors (#2012)

* add the sign unary operator

* remove uneeded import

* remove uneeded import

* undo formatting

* undo formatting

* remove unnecessary redefintion

* allow gradient to flow through for sign and round

* fix cpu ops to ensure that negzero and positive zero are handled properly

* clippy fixes

* Properly avoid gradient tracking.

* Use a branchless version.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Thomas Santerre
2024-04-04 16:32:47 -04:00
committed by GitHub
parent e6a5b82ba6
commit c5626b8271
8 changed files with 69 additions and 11 deletions

View File

@ -86,6 +86,11 @@ extern "C" __global__ void FN_NAME( \
} \
} \
template<typename T>
__device__ T sign_(T t) {
return static_cast<T>(t > static_cast<T>(0)) - static_cast<T>(t < static_cast<T>(0));
}
#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
@ -110,6 +115,7 @@ UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
#endif
#if __CUDA_ARCH__ >= 530
@ -135,6 +141,7 @@ UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP(__half, usilu_f16, silu_fwd(x))
UNARY_OP1(__half, upowf_f16, powg(x, param))
UNARY_OP(__half, usign_f16, sign_(x))
#endif
UNARY_OP(uint8_t, ucopy_u8, x)
@ -184,3 +191,5 @@ UNARY_OP(float, usilu_f32, silu_fwd(x))
UNARY_OP(double, usilu_f64, silu_fwd(x))
UNARY_OP1(float, upowf_f32, powg(x, param))
UNARY_OP1(double, upowf_f64, powg(x, param))
UNARY_OP(float, usign_f32, sign_(x))
UNARY_OP(double, usign_f64, sign_(x))