Add the rounding operators. (#1030)

* Add the rounding operators.

* Avoid tracking gradients for the rounding operations.

* Add some rounding tests.
This commit is contained in:
Laurent Mazare
2023-10-04 17:58:44 +01:00
committed by GitHub
parent 3349c89252
commit c18a856e76
6 changed files with 157 additions and 0 deletions

View File

@ -92,6 +92,9 @@ UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x))
UNARY_OP(__nv_bfloat16, uceil_bf16, ceilg(x))
UNARY_OP(__nv_bfloat16, ufloor_bf16, floorg(x))
UNARY_OP(__nv_bfloat16, uround_bf16, roundg(x))
UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
@ -113,6 +116,9 @@ UNARY_OP(__half, usin_f16, sing(x))
UNARY_OP(__half, ucos_f16, cosg(x))
UNARY_OP(__half, utanh_f16, tanhg(x))
UNARY_OP(__half, uerf_f16, erfg(x))
UNARY_OP(__half, uceil_f16, ceilg(x))
UNARY_OP(__half, ufloor_f16, floorg(x))
UNARY_OP(__half, uround_f16, roundg(x))
UNARY_OP(__half, unormcdf_f16, normcdfg(x))
UNARY_OP(__half, uabs_f16, absg(x))
UNARY_OP(__half, usqr_f16, x*x)
@ -145,6 +151,12 @@ UNARY_OP(float, utanh_f32, tanhg(x))
UNARY_OP(double, utanh_f64, tanhg(x))
UNARY_OP(float, uerf_f32, erfg(x))
UNARY_OP(double, uerf_f64, erfg(x))
UNARY_OP(float, uceil_f32, ceilg(x))
UNARY_OP(double, uceil_f64, ceilg(x))
UNARY_OP(float, ufloor_f32, floorg(x))
UNARY_OP(double, ufloor_f64, floorg(x))
UNARY_OP(float, uround_f32, roundg(x))
UNARY_OP(double, uround_f64, roundg(x))
UNARY_OP(float, unormcdf_f32, normcdfg(x))
UNARY_OP(double, unormcdf_f64, normcdfg(x))
UNARY_OP(float, uabs_f32, absg(x))