diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index f39eedbb..65d91849 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -112,7 +112,8 @@ impl Tensor { } Op::Unary(_node, UnaryOp::Ceil) | Op::Unary(_node, UnaryOp::Floor) - | Op::Unary(_node, UnaryOp::Round) => nodes, + | Op::Unary(_node, UnaryOp::Round) + | Op::Unary(_node, UnaryOp::Sign) => nodes, Op::Reshape(node) | Op::UpsampleNearest1D { arg: node, .. } | Op::UpsampleNearest2D { arg: node, .. } @@ -488,7 +489,6 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad)?; } - Op::Cmp(_args, _) => {} Op::Reduce(arg, ReduceOp::Max, reduced_dims) => { let node = broadcast_back(arg, node, reduced_dims)?; let grad = broadcast_back(arg, &grad, reduced_dims)?; @@ -578,20 +578,18 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } - Op::Reduce(_, ReduceOp::ArgMin, _) => {} - Op::Reduce(_, ReduceOp::ArgMax, _) => {} + Op::Unary(_, UnaryOp::Floor) + | Op::Unary(_, UnaryOp::Round) + | Op::Reduce(_, ReduceOp::ArgMin, _) + | Op::Reduce(_, ReduceOp::ArgMax, _) + | Op::Unary(_, UnaryOp::Sign) + | Op::Cmp(_, _) => {} Op::Reshape(arg) => { let arg_grad = grad.reshape(arg.dims())?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?, - Op::Unary(_, UnaryOp::Floor) => { - Err(Error::BackwardNotSupported { op: "floor" })? - } - Op::Unary(_, UnaryOp::Round) => { - Err(Error::BackwardNotSupported { op: "round" })? 
- } Op::Unary(arg, UnaryOp::Gelu) => { let sum_grad = grads.or_insert(arg)?; let cube = arg.powf(3.)?; diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index fa6973b4..0e058b45 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -497,6 +497,10 @@ impl BackendStorage for MetalStorage { ("utanh", DType::F16) => contiguous::tanh::HALF, ("utanh", DType::F32) => contiguous::tanh::FLOAT, ("utanh", DType::BF16) => contiguous::tanh::BFLOAT, + ("usign", DType::F16) => contiguous::sign::HALF, + ("usign", DType::F32) => contiguous::sign::FLOAT, + ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I64) => contiguous::sign::I64, (name, dtype) => { crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 776f5182..49ba44be 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -66,6 +66,7 @@ pub enum UnaryOp { Floor, Ceil, Round, + Sign, } #[derive(Clone)] @@ -254,6 +255,7 @@ pub(crate) struct Tanh; pub(crate) struct Floor; pub(crate) struct Ceil; pub(crate) struct Round; +pub(crate) struct Sign; macro_rules! bin_op { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { @@ -925,3 +927,37 @@ impl std::ops::Deref for BackpropOp { &self.0 } } + +impl UnaryOpT for Sign { + const NAME: &'static str = "sign"; + const KERNEL: &'static str = "usign"; + const V: Self = Sign; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + f32::from(v > 0.) - f32::from(v < 0.) + } + #[inline(always)] + fn f64(v: f64) -> f64 { + f64::from(v > 0.) - f64::from(v < 0.) 
+ } + #[inline(always)] + fn u8(v: u8) -> u8 { + u8::min(1, v) + } + #[inline(always)] + fn u32(v: u32) -> u32 { + u32::min(1, v) + } + #[inline(always)] + fn i64(v: i64) -> i64 { + (v > 0) as i64 - (v < 0) as i64 + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index b53b0419..a5a9dbb1 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -510,6 +510,7 @@ impl Tensor { unary_op!(ceil, Ceil); unary_op!(floor, Floor); unary_op!(round, Round); + unary_op!(sign, Sign); /// Round element of the input tensor to the nearest integer. /// diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index b3275804..78841779 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -151,6 +151,14 @@ fn unary_op(device: &Device) -> Result<()> { test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?, [3000.0, 300.] ); + let tensor = Tensor::new( + &[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1], + device, + )?; + assert_eq!( + tensor.sign()?.to_vec1::<f32>()?, + [-1., -1., -1., 0., 0., 1., 1., 1., 1.] 
+ ); Ok(()) } diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index 13489897..a234304a 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -86,6 +86,11 @@ extern "C" __global__ void FN_NAME( \ } \ } \ +template<typename T> +__device__ T sign_(T t) { + return static_cast<T>(t > static_cast<T>(0)) - static_cast<T>(t < static_cast<T>(0)); +} + #if __CUDA_ARCH__ >= 800 UNARY_OP(__nv_bfloat16, ucopy_bf16, x) @@ -110,6 +115,7 @@ UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x)) UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param)) UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x)) UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param)) +UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x)) #endif #if __CUDA_ARCH__ >= 530 @@ -135,6 +141,7 @@ UNARY_OP(__half, urelu_f16, relu_fwd(x)) UNARY_OP1(__half, uelu_f16, elu_fwd(x, param)) UNARY_OP(__half, usilu_f16, silu_fwd(x)) UNARY_OP1(__half, upowf_f16, powg(x, param)) +UNARY_OP(__half, usign_f16, sign_(x)) #endif UNARY_OP(uint8_t, ucopy_u8, x) @@ -184,3 +191,5 @@ UNARY_OP(float, usilu_f32, silu_fwd(x)) UNARY_OP(double, usilu_f64, silu_fwd(x)) UNARY_OP1(float, upowf_f32, powg(x, param)) UNARY_OP1(double, upowf_f64, powg(x, param)) +UNARY_OP(float, usign_f32, sign_(x)) +UNARY_OP(double, usign_f64, sign_(x)) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 140927e3..5af48fae 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -193,7 +193,7 @@ macro_rules! 
ops{ pub mod unary { ops!( cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, - tanh, recip, silu + tanh, recip, silu, sign ); } pub mod binary { diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index bdc13f9e..809522d7 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -145,6 +145,7 @@ UNARY_OP(erf) UNARY_OP(tanh) UNARY_OP(recip) UNARY_OP(relu) +UNARY_OP(sign) UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) @@ -174,6 +175,7 @@ BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(recip) BFLOAT_UNARY_OP(relu) +BFLOAT_UNARY_OP(sign) UNARY(id, bfloat, copy_bf16, copy_bf16_strided)