From 7b26e513f15a0c7cd55ccfe48525bda1079427ce Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 21 Sep 2023 06:19:10 +0100 Subject: [PATCH] Add the erf function. (#917) --- candle-core/src/backprop.rs | 1 + candle-core/src/op.rs | 36 +++++++++++++++++++++++++++++++ candle-core/src/tensor.rs | 1 + candle-core/tests/tensor_tests.rs | 7 ++++++ 4 files changed, 45 insertions(+) diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 3e2ae1ed..a2548198 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -442,6 +442,7 @@ impl Tensor { *sum_grad = sum_grad.add(&arg_grad)? } Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?, + Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?, Op::Unary(_, UnaryOp::GeluErf) => { Err(Error::BackwardNotSupported { op: "gelu-erf" })? } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 26dc6609..4882a205 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -59,6 +59,7 @@ pub enum UnaryOp { Sqrt, Gelu, GeluErf, + Erf, Relu, Tanh, } @@ -327,6 +328,7 @@ pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; pub(crate) struct GeluErf; +pub(crate) struct Erf; pub(crate) struct Relu; pub(crate) struct Tanh; @@ -623,6 +625,40 @@ impl UnaryOpT for Gelu { } } +impl UnaryOpT for Erf { + const NAME: &'static str = "erf"; + const KERNEL: &'static str = "uerf"; + const V: Self = Erf; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + Self::f64(v as f64) as f32 + } + #[inline(always)] + fn f64(v: f64) -> f64 { + crate::cpu::erf::erf(v) + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } +} + impl UnaryOpT for GeluErf { const NAME: &'static str = "gelu_erf"; const KERNEL: &'static str = "ugelu_erf"; diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index eafd7002..9dccf2b5 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -490,6 +490,7 @@ impl Tensor { unary_op!(sqrt, Sqrt); unary_op!(gelu, Gelu); unary_op!(gelu_erf, GeluErf); + unary_op!(erf, Erf); unary_op!(relu, Relu); /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 408f4c55..edd0bd79 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -61,6 +61,13 @@ fn unary_op(device: &Device) -> Result<()> { [2.6906, -0.0647, -0.1091, 1.7353, 2.7928] ] ); + assert_eq!( + test_utils::to_vec2_round(&tensor.erf()?, 4)?, + [ + [-1.0, 0.8427, 1.0, -0.1125, 0.5205], + [0.9999, -0.9891, -0.3079, 0.9891, 0.9999] + ] + ); Ok(()) }