Add an erf-based gelu op (#900)

* Erf-based gelu.

* Add the erf-backed gelu.

* Test the new gelu op (which is not gelu_new).
This commit is contained in:
Laurent Mazare
2023-09-19 19:54:28 +01:00
committed by GitHub
parent 34f2ecbc3b
commit d7e48234d4
8 changed files with 851 additions and 1 deletions

View File

@ -58,6 +58,7 @@ pub enum UnaryOp {
Sqr,
Sqrt,
Gelu,
GeluErf,
Relu,
Tanh,
}
@ -325,6 +326,7 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
/// Unit struct tagging the erf-based GELU op; its `UnaryOpT` impl carries the
/// per-dtype implementations. Separate from the `Gelu` tag declared alongside it.
pub(crate) struct GeluErf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
@ -621,6 +623,40 @@ impl UnaryOpT for Gelu {
}
}
/// Erf-based GELU: `gelu_erf(x) = 0.5 * (1 + erf(x / sqrt(2))) * x`.
///
/// The formula is written once, in `f64`; the narrower float types widen to
/// `f64`, evaluate it there, and narrow the result back. The integer variants
/// ignore their argument and return `0`.
impl UnaryOpT for GeluErf {
    const NAME: &'static str = "gelu_erf";
    // NOTE(review): presumably the name of the matching device kernel — confirm
    // against the kernel sources.
    const KERNEL: &'static str = "ugelu_erf";
    const V: Self = Self;

    /// Reference implementation that every other float width delegates to.
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        let scaled = v / std::f64::consts::SQRT_2;
        // `0.5 * (1 + erf(..))` is evaluated first and then multiplied by `v`,
        // keeping the same floating-point evaluation order as before.
        let weight = 0.5 * (1. + crate::cpu::erf::erf(scaled));
        weight * v
    }

    /// Widens to `f64`, applies the op, then narrows back to `f32`.
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        let wide = f64::from(v);
        Self::f64(wide) as f32
    }

    /// Half-precision variant, computed through `f64`.
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        let wide = v.to_f64();
        f16::from_f64(Self::f64(wide))
    }

    /// bfloat16 variant, computed through `f64`.
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        let wide = v.to_f64();
        bf16::from_f64(Self::f64(wide))
    }

    // Integer inputs always map to 0 — presumably because the op is only
    // meaningful for floating-point dtypes; verify against callers.
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }

    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }

    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
}
impl UnaryOpT for Relu {
const NAME: &'static str = "relu";
const KERNEL: &'static str = "urelu";