Add support for "sign" on tensors (#2012)

* add the sign unary operator

* remove uneeded import

* remove uneeded import

* undo formatting

* undo formatting

* remove unnecessary redefintion

* allow gradient to flow through for sign and round

* fix cpu ops to ensure that negzero and positive zero are handled properly

* clippy fixes

* Properly avoid gradient tracking.

* Use a branchless version.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Thomas Santerre
2024-04-04 16:32:47 -04:00
committed by GitHub
parent e6a5b82ba6
commit c5626b8271
8 changed files with 69 additions and 11 deletions

View File

@ -193,7 +193,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip, silu
tanh, recip, silu, sign
);
}
pub mod binary {

View File

@ -145,6 +145,7 @@ UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY_OP(sign)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@ -174,6 +175,7 @@ BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
BFLOAT_UNARY_OP(sign)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)