Add support for "sign" on tensors (#2012)

* add the sign unary operator

* remove uneeded import

* remove uneeded import

* undo formatting

* undo formatting

* remove unnecessary redefintion

* allow gradient to flow through for sign and round

* fix cpu ops to ensure that negzero and positive zero are handled properly

* clippy fixes

* Properly avoid gradient tracking.

* Use a branchless version.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Thomas Santerre
2024-04-04 16:32:47 -04:00
committed by GitHub
parent e6a5b82ba6
commit c5626b8271
8 changed files with 69 additions and 11 deletions

View File

@ -193,7 +193,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip, silu
tanh, recip, silu, sign
);
}
pub mod binary {