Add the erf function. (#917)

This commit is contained in:
Laurent Mazare
2023-09-21 06:19:10 +01:00
committed by GitHub
parent ab1d40ea97
commit 7b26e513f1
4 changed files with 45 additions and 0 deletions

View File

@ -442,6 +442,7 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
}

View File

@ -59,6 +59,7 @@ pub enum UnaryOp {
Sqrt,
Gelu,
GeluErf,
Erf,
Relu,
Tanh,
}
@ -327,6 +328,7 @@ pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
@ -623,6 +625,40 @@ impl UnaryOpT for Gelu {
}
}
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
const V: Self = Erf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
crate::cpu::erf::erf(v)
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";

View File

@ -490,6 +490,7 @@ impl Tensor {
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf);
unary_op!(relu, Relu);
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple

View File

@ -61,6 +61,13 @@ fn unary_op(device: &Device) -> Result<()> {
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
[
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
]
);
Ok(())
}