mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add support for "sign" on tensors (#2012)
* add the sign unary operator * remove uneeded import * remove uneeded import * undo formatting * undo formatting * remove unnecessary redefintion * allow gradient to flow through for sign and round * fix cpu ops to ensure that negzero and positive zero are handled properly * clippy fixes * Properly avoid gradient tracking. * Use a branchless version. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -193,7 +193,7 @@ macro_rules! ops{
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip, silu
|
||||
tanh, recip, silu, sign
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
|
@ -145,6 +145,7 @@ UNARY_OP(erf)
|
||||
UNARY_OP(tanh)
|
||||
UNARY_OP(recip)
|
||||
UNARY_OP(relu)
|
||||
UNARY_OP(sign)
|
||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
@ -174,6 +175,7 @@ BFLOAT_UNARY_OP(erf)
|
||||
BFLOAT_UNARY_OP(tanh)
|
||||
BFLOAT_UNARY_OP(recip)
|
||||
BFLOAT_UNARY_OP(relu)
|
||||
BFLOAT_UNARY_OP(sign)
|
||||
|
||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
|
||||
|
Reference in New Issue
Block a user