diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index b7f99f11..e1168c2e 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -536,7 +536,6 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln); unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin); unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos); unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh); -unary_op!(Abs, "abs", v, v.abs()); unary_op!(Neg, "neg", v, -v); unary_op!(Recip, "recip", v, v.recip()); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); @@ -666,6 +665,40 @@ impl UnaryOpT for Erf { } } +impl UnaryOpT for Abs { + const NAME: &'static str = "abs"; + const KERNEL: &'static str = "uabs"; + const V: Self = Abs; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.abs() + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.abs() + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v.abs() + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.abs() + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v.abs() + } +} + impl UnaryOpT for Ceil { const NAME: &'static str = "ceil"; const KERNEL: &'static str = "uceil"; diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index ae1bd058..899efcf3 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1089,3 +1089,11 @@ fn pad_with_same() -> Result<()> { ); Ok(()) } + +#[test] +fn i64_abs() -> Result<()> { + let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?; + let t = t.abs()?; + assert_eq!(t.to_vec1::()?, [42, 1337]); + Ok(()) +}